summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTyler Seip <Tyler.Seip@mongodb.com>2021-10-29 20:44:40 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-10-29 21:27:26 +0000
commit786482da93c3e5e58b1c690cb060f00c60864f69 (patch)
tree9d63c02f889251d2de875e71713f6d53fa2b2040
parent6ed9f24dc8470f6f7010f7b5cf6dba8db7c9e376 (diff)
downloadmongo-786482da93c3e5e58b1c690cb060f00c60864f69.tar.gz
SERVER-60677: Implement parser for Proxy Protocol V1 and V2 headers
-rw-r--r--src/mongo/transport/SConscript2
-rw-r--r--src/mongo/transport/proxy_protocol_header_parser.cpp453
-rw-r--r--src/mongo/transport/proxy_protocol_header_parser.h105
-rw-r--r--src/mongo/transport/proxy_protocol_header_parser_test.cpp573
4 files changed, 1133 insertions, 0 deletions
diff --git a/src/mongo/transport/SConscript b/src/mongo/transport/SConscript
index b166e728719..295bf0f5680 100644
--- a/src/mongo/transport/SConscript
+++ b/src/mongo/transport/SConscript
@@ -55,6 +55,7 @@ tlEnv.Library(
'transport_layer_asio.cpp',
'asio_utils.cpp',
'session_asio.cpp',
+ 'proxy_protocol_header_parser.cpp',
'transport_options.idl',
],
LIBDEPS=[
@@ -186,6 +187,7 @@ tlEnv.CppUnitTest(
'service_executor_test.cpp',
'max_conns_override_test.cpp',
'service_state_machine_test.cpp',
+ 'proxy_protocol_header_parser_test.cpp',
],
LIBDEPS=[
'$BUILD_DIR/mongo/base',
diff --git a/src/mongo/transport/proxy_protocol_header_parser.cpp b/src/mongo/transport/proxy_protocol_header_parser.cpp
new file mode 100644
index 00000000000..2a0b672f924
--- /dev/null
+++ b/src/mongo/transport/proxy_protocol_header_parser.cpp
@@ -0,0 +1,453 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kNetwork
+
+#include "mongo/transport/proxy_protocol_header_parser.h"
+
+#include <boost/optional.hpp>
+#include <cstring>
+#include <exception>
+#include <fmt/format.h>
+
+#ifndef _WIN32
+#include <netinet/in.h>
+#include <sys/un.h>
+#endif
+
+#include "mongo/base/parse_number.h"
+#include "mongo/base/string_data.h"
+#include "mongo/logv2/log.h"
+#include "mongo/platform/endian.h"
+#include "mongo/util/assert_util.h"
+
+namespace mongo::transport {
+
+using namespace fmt::literals;
+
+namespace {
+StringData parseToken(StringData& s, const char c) {
+ size_t pos = s.find(c);
+ uassert(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string malformed: {}"_format(s),
+ pos != std::string::npos);
+ StringData result = s.substr(0, pos);
+ s = s.substr(pos + 1);
+ return result;
+}
+} // namespace
+
+namespace proxy_protocol_details {
+
+void validateIpv4Address(StringData addr) {
+ StringData buffer = addr;
+ const NumberParser octetParser =
+ NumberParser().skipWhitespace(false).base(10).allowTrailingText(false);
+ try {
+ for (size_t i = 0; i < 4; ++i) {
+ unsigned octet = 0;
+ if (i == 3) {
+ uassertStatusOK(octetParser(buffer, &octet));
+ } else {
+ uassertStatusOK(octetParser(parseToken(buffer, '.'), &octet));
+ }
+ uassert(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified malformed IPv4 address: {}"_format(
+ addr),
+ octet <= 255);
+ }
+ } catch (const ExceptionFor<ErrorCodes::FailedToParse>&) {
+ uasserted(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified malformed IPv4 address: {}"_format(
+ addr));
+ }
+}
+
+void validateIpv6Address(StringData addr) {
+ static constexpr StringData doubleColon = "::"_sd;
+
+ auto validateHexadectets = [](StringData buffer) -> size_t {
+ auto validateHexadectet = [](StringData hexadectet) {
+ const NumberParser hexadectetParser =
+ NumberParser().skipWhitespace(false).base(16).allowTrailingText(false);
+ unsigned value = 0;
+ uassertStatusOK(hexadectetParser(hexadectet, &value));
+ uassert(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format(
+ hexadectet),
+ hexadectet.size() == 4 && value >= 0);
+ };
+
+ if (buffer.empty())
+ return 0;
+
+ uassert(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format(
+ buffer),
+ buffer.find(doubleColon) == std::string::npos);
+
+ size_t numHexadectets = 0;
+ while (!buffer.empty()) {
+ if (const size_t pos = buffer.find(':'); pos != std::string::npos) {
+ validateHexadectet(buffer.substr(0, pos));
+ ++numHexadectets;
+ buffer = buffer.substr(pos + 1);
+ } else {
+ validateHexadectet(buffer);
+ return numHexadectets + 1;
+ }
+ }
+ uasserted(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string contains malformed IPv6 hexadectet: {}"_format(
+ buffer));
+ };
+
+ // There can be at most one double colon in our address. Split on the first
+ // one and validate neither half has another implicitly.
+ try {
+ if (const auto pos = addr.find(doubleColon); pos != std::string::npos) {
+ const size_t numHexadectets = validateHexadectets(addr.substr(0, pos)) +
+ validateHexadectets(addr.substr(pos + doubleColon.size()));
+ uassert(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format(
+ addr),
+ numHexadectets < 8);
+ } else {
+ const size_t numHexadectets = validateHexadectets(addr);
+ uassert(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format(
+ addr),
+ numHexadectets == 8);
+ }
+ } catch (const ExceptionFor<ErrorCodes::FailedToParse>&) {
+ uasserted(
+ ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified malformed IPv6 address: {}"_format(
+ addr));
+ }
+}
+
+} // namespace proxy_protocol_details
+
+namespace {
+
+// Interprets the first sizeof(T) bytes of data as a T and returns it, advancing the data cursor
+// by the same amount. Does not account for endianness of the data buffer.
+template <typename T>
+T extract(StringData& data) {
+ MONGO_STATIC_ASSERT(std::is_trivially_copyable_v<T>);
+ static constexpr size_t numBytes = sizeof(T);
+ if (data.size() < numBytes) {
+ throw std::out_of_range("Not enough space to extract object of size {}"_format(numBytes));
+ }
+
+ T result;
+ memcpy(&result, data.rawData(), numBytes);
+ data = data.substr(numBytes);
+ return result;
+}
+
+constexpr StringData kV1Start = "PROXY"_sd;
+
+bool parseV1Buffer(StringData& buffer, boost::optional<ProxiedEndpoints>& endpoints) {
+ buffer = buffer.substr(kV1Start.size());
+ if (buffer.empty())
+ return false;
+
+ // Scan the buffer for a newline and prepare an output buffer which begins just past
+ // the line.
+ static constexpr StringData crlf = "\x0D\x0A"_sd;
+ const auto crlfPos = buffer.find(crlf);
+
+ static constexpr size_t kMaximumV1HeaderSize = 107;
+ static constexpr size_t kMaximumV1InetLineSize = kMaximumV1HeaderSize - kV1Start.size();
+ if (crlfPos == std::string::npos) {
+ // If we couldn't find a newline sequence, then fail if there cannot be enough room
+ // for one to appear in the future.
+ uassert(ErrorCodes::FailedToParse,
+ "No terminating newline found in Proxy Protocol header V1: {}"_format(buffer),
+ buffer.size() <= kMaximumV1InetLineSize);
+ return false;
+ } else {
+ // If we could, then fail if the sequence doesn't occur within the maximum line length.
+ uassert(ErrorCodes::FailedToParse,
+ "No terminating newline found in Proxy Protocol header V1: {}"_format(buffer),
+ crlfPos + crlf.size() <= kMaximumV1InetLineSize);
+ }
+
+ // Prepare a result buffer pointing to just after the crlf sequence.
+ const auto resultBuffer = buffer.substr(crlfPos + crlf.size());
+
+ static constexpr StringData kTcp4Prefix = " TCP4 "_sd;
+ static constexpr StringData kTcp6Prefix = " TCP6 "_sd;
+ int aFamily = AF_UNSPEC;
+ if (buffer.startsWith(kTcp4Prefix)) {
+ aFamily = AF_INET;
+ buffer = buffer.substr(kTcp4Prefix.size());
+ } else if (buffer.startsWith(kTcp6Prefix)) {
+ aFamily = AF_INET6;
+ buffer = buffer.substr(kTcp6Prefix.size());
+ } else if (buffer.startsWith(" UNKNOWN"_sd)) {
+ buffer = resultBuffer;
+ endpoints = {};
+ return true;
+ } else {
+ uasserted(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string malformed: {}"_format(buffer));
+ }
+
+ // The remainder of the string should now tokenize into four substrings:
+ // srcAddr dstAddr srcPort dstPort
+ const StringData srcAddr = parseToken(buffer, ' ');
+ const StringData dstAddr = parseToken(buffer, ' ');
+
+ invariant(aFamily == AF_INET || aFamily == AF_INET6);
+ if (aFamily == AF_INET) {
+ proxy_protocol_details::validateIpv4Address(srcAddr);
+ proxy_protocol_details::validateIpv4Address(dstAddr);
+ } else {
+ proxy_protocol_details::validateIpv6Address(srcAddr);
+ proxy_protocol_details::validateIpv6Address(dstAddr);
+ }
+
+ const StringData srcPortStr = parseToken(buffer, ' ');
+ const StringData dstPortStr = parseToken(buffer, '\r');
+
+ const NumberParser portParser =
+ NumberParser().skipWhitespace(false).base(10).allowTrailingText(false);
+ unsigned srcPort, dstPort = 0;
+ uassertStatusOK(portParser(srcPortStr, &srcPort));
+ uassertStatusOK(portParser(dstPortStr, &dstPort));
+
+ auto validatePort = [](int port) {
+ uassert(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified invalid port: {}"_format(port),
+ port <= 65535);
+ };
+ validatePort(srcPort);
+ validatePort(dstPort);
+
+ buffer = resultBuffer;
+ try {
+ endpoints = ProxiedEndpoints{SockAddr::create(srcAddr, srcPort, aFamily),
+ SockAddr::create(dstAddr, dstPort, aFamily)};
+ return true;
+ } catch (const ExceptionFor<ErrorCodes::HostUnreachable>&) {
+ // SockAddr can throw on construction if the address passed in is malformed.
+ uasserted(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 1 address string specified unreachable host: {}"_format(
+ buffer));
+ }
+}
+
+// Since this string contains a null, it's critical we use a literal here.
+constexpr StringData kV2Start = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"_sd;
+
+bool parseV2Buffer(StringData& buffer, boost::optional<ProxiedEndpoints>& endpoints) {
+ buffer = buffer.substr(kV2Start.size());
+ if (buffer.empty())
+ return false;
+
+ try {
+ const char protocolVersionAndCommandByte = extract<char>(buffer);
+
+ bool isLocal = false;
+ // The high nibble must be 2 (version 2) and the low nibble must be either 0 or 1 (local or
+ // remote).
+ if (protocolVersionAndCommandByte == '\x20') {
+ isLocal = true;
+ } else if (protocolVersionAndCommandByte != '\x21') {
+ uasserted(
+ ErrorCodes::FailedToParse,
+ "Invalid version or command byte given in Proxy Protocol header V2: {}"_format(
+ protocolVersionAndCommandByte));
+ }
+
+ if (buffer.empty())
+ return false;
+
+ const uint8_t transportProtocolAndAddressFamilyByte = extract<uint8_t>(buffer);
+
+ int aFamily = 0;
+ {
+ // Discard the family if this is a local connection.
+ uint8_t aFamilyByte = isLocal ? 0 : (transportProtocolAndAddressFamilyByte & 0xF0) >> 4;
+ switch (aFamilyByte) {
+ case 0:
+ aFamily = AF_UNSPEC;
+ break;
+ case 1:
+ aFamily = AF_INET;
+ break;
+ case 2:
+ aFamily = AF_INET6;
+ break;
+ case 3:
+ aFamily = AF_UNIX;
+ break;
+ default:
+ uasserted(ErrorCodes::FailedToParse,
+ "Invalid address family given in Proxy Protocol header V2: {}"_format(
+ aFamilyByte));
+ }
+ }
+
+ uint8_t protocol = (transportProtocolAndAddressFamilyByte & 0xF);
+ uassert(ErrorCodes::FailedToParse,
+ "Invalid protocol given in Proxy Protocol header V2: {}"_format(protocol),
+ protocol <= 0x2);
+
+ // If protocol is unspecified, we should also ignore address information.
+ if (protocol == 0) {
+ aFamily = AF_UNSPEC;
+ }
+
+ if (buffer.size() < sizeof(uint16_t))
+ return false;
+
+ const size_t length = endian::bigToNative(extract<uint16_t>(buffer));
+ if (buffer.size() < length)
+ return false;
+
+ // Prepare an output buffer that skips past the end of the header.
+ // We'll assign this to the buffer if we fully succeed in parsing the header.
+ const auto resultBuffer = buffer.substr(length);
+
+ switch (aFamily) {
+ case AF_UNSPEC:
+ break;
+ case AF_INET: {
+ // The proxy protocol allocates 12 bytes to represent a pair of IPv4 addresses
+ // along with their ports.
+ static constexpr size_t kIPv4ProxyProtocolSize = 12;
+ MONGO_STATIC_ASSERT(2 * (sizeof(in_addr) + sizeof(uint16_t)) ==
+ kIPv4ProxyProtocolSize);
+ uassert(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 2 address string too short: {}"_format(buffer),
+ length >= kIPv4ProxyProtocolSize);
+ sockaddr_in src_addr{};
+ sockaddr_in dst_addr{};
+ src_addr.sin_family = dst_addr.sin_family = AF_INET;
+ // These are specified by the protocol to be in network byte order, which
+ // is what sin_addr/sin_port expect, so we copy them directly.
+ src_addr.sin_addr = extract<in_addr>(buffer);
+ dst_addr.sin_addr = extract<in_addr>(buffer);
+ src_addr.sin_port = extract<uint16_t>(buffer);
+ dst_addr.sin_port = extract<uint16_t>(buffer);
+ endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_in)),
+ SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_in))};
+ break;
+ }
+ case AF_INET6: {
+ // The proxy protocol allocates 36 bytes to represent a pair of IPv6 addresses
+ // along with their ports.
+ static constexpr size_t kIPv6ProxyProtocolSize = 36;
+ MONGO_STATIC_ASSERT(2 * (sizeof(in6_addr) + sizeof(uint16_t)) ==
+ kIPv6ProxyProtocolSize);
+ uassert(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 2 address string too short: {}"_format(buffer),
+ length >= kIPv6ProxyProtocolSize);
+ sockaddr_in6 src_addr{};
+ sockaddr_in6 dst_addr{};
+ src_addr.sin6_family = dst_addr.sin6_family = AF_INET6;
+ // These are specified by the protocol to be in network byte order, which
+ // is what sin_addr/sin_port expect, so we copy them directly.
+ src_addr.sin6_addr = extract<in6_addr>(buffer);
+ dst_addr.sin6_addr = extract<in6_addr>(buffer);
+ src_addr.sin6_port = extract<uint16_t>(buffer);
+ dst_addr.sin6_port = extract<uint16_t>(buffer);
+ endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_in6)),
+ SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_in6))};
+ break;
+ }
+ case AF_UNIX: {
+ // The proxy protocol allocates 216 bytes to represent a pair of UNIX address,
+ // but we don't assert type sizes here because some platforms don't support
+ // UNIX addresses of this length - they are checked in parseSockAddrUn.
+ static constexpr size_t kUnixProxyProtocolSize = 216;
+ uassert(ErrorCodes::FailedToParse,
+ "Proxy Protocol Version 2 address string too short: {}"_format(buffer),
+ length >= kUnixProxyProtocolSize);
+ const auto src_addr = proxy_protocol_details::parseSockAddrUn(
+ buffer.substr(0, proxy_protocol_details::kMaxUnixPathLength));
+ const auto dst_addr = proxy_protocol_details::parseSockAddrUn(
+ buffer.substr(proxy_protocol_details::kMaxUnixPathLength,
+ proxy_protocol_details::kMaxUnixPathLength));
+
+ endpoints = ProxiedEndpoints{SockAddr((sockaddr*)&src_addr, sizeof(sockaddr_un)),
+ SockAddr((sockaddr*)&dst_addr, sizeof(sockaddr_un))};
+ break;
+ }
+ default:
+ MONGO_UNREACHABLE;
+ }
+ buffer = resultBuffer;
+ return true;
+ } catch (const std::out_of_range&) {
+ return false;
+ }
+}
+
+} // namespace
+
+boost::optional<ParserResults> parseProxyProtocolHeader(StringData buffer) {
+ // Check if the buffer presented is V1, V2, or neither.
+ const size_t originalBufferSize = buffer.size();
+
+ ParserResults results;
+ bool complete = false;
+ if (buffer.startsWith(kV1Start)) {
+ complete = parseV1Buffer(buffer, results.endpoints);
+ } else if (buffer.startsWith(kV2Start)) {
+ complete = parseV2Buffer(buffer, results.endpoints);
+ } else {
+ uassert(ErrorCodes::FailedToParse,
+ "Initial Proxy Protocol header bytes invalid: {}"
+ "; Make sure your proxy is configured to emit a Proxy "
+ "Protocol header"_format(buffer),
+ kV1Start.startsWith(buffer) || kV2Start.startsWith(buffer));
+ }
+
+ if (complete) {
+ results.bytesParsed = originalBufferSize - buffer.size();
+ return results;
+ } else {
+ return {};
+ }
+}
+
+
+} // namespace mongo::transport
diff --git a/src/mongo/transport/proxy_protocol_header_parser.h b/src/mongo/transport/proxy_protocol_header_parser.h
new file mode 100644
index 00000000000..f3a3e6126ab
--- /dev/null
+++ b/src/mongo/transport/proxy_protocol_header_parser.h
@@ -0,0 +1,105 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <boost/optional.hpp>
+#include <fmt/format.h>
+
+#ifndef _WIN32
+#include <sys/un.h>
+#endif
+
+#include "mongo/base/error_codes.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/net/sockaddr.h"
+
+namespace mongo::transport {
+
+/**
+ * Represents the true endpoints that a proxy using the Proxy Protocol is proxying for us.
+ */
+struct ProxiedEndpoints {
+ // The true origin of the connection, i.e. the IP address of the client behind the
+ // proxy.
+ SockAddr sourceAddress;
+ // The true destination of the connection, almost always the address that the proxy
+ // is listening on.
+ SockAddr destinationAddress;
+};
+
+/**
+ * Contains the results of parsing a Proxy Protocol header. bytesParsed contains the
+ * length of the parsed header, and endpoints contains any endpoint information that the
+ * header optionally contained.
+ */
+struct ParserResults {
+ // The endpoint metadata should be populated iff parsing is complete, the connection
+ // is marked as remote, and the connection is not marked as UNKNOWN.
+ boost::optional<ProxiedEndpoints> endpoints = {};
+ size_t bytesParsed = 0;
+};
+
+/**
+ * Parses a string potentially starting with a proxy protocol header (either V1 or V2).
+ * If the string begins with a partial but incomplete header, returns an empty optional;
+ * otherwise, returns a ParserResults with the results of the parse.
+ *
+ * Will throw eagerly on a malformed header.
+ */
+boost::optional<ParserResults> parseProxyProtocolHeader(StringData buffer);
+
+namespace proxy_protocol_details {
+// The maximum number of bytes ever needed by a proxy protocol header; represents
+// the minimum TCP MTU.
+static constexpr size_t kBytesToFetch = 536;
+static constexpr size_t kMaxUnixPathLength = 108;
+
+template <typename AddrUn = sockaddr_un>
+AddrUn parseSockAddrUn(StringData buffer) {
+ using namespace fmt::literals;
+
+ AddrUn addr{};
+ addr.sun_family = AF_UNIX;
+
+ StringData path = buffer.substr(0, buffer.find('\0'));
+ uassert(ErrorCodes::FailedToParse,
+ "Provided unix path longer than system supports: {}"_format(buffer),
+ path.size() < sizeof(AddrUn::sun_path));
+ std::copy(path.begin(), path.end(), addr.sun_path);
+ return addr;
+}
+
+void validateIpv4Address(StringData addr);
+void validateIpv6Address(StringData addr);
+
+} // namespace proxy_protocol_details
+
+} // namespace mongo::transport
diff --git a/src/mongo/transport/proxy_protocol_header_parser_test.cpp b/src/mongo/transport/proxy_protocol_header_parser_test.cpp
new file mode 100644
index 00000000000..8de76feb091
--- /dev/null
+++ b/src/mongo/transport/proxy_protocol_header_parser_test.cpp
@@ -0,0 +1,573 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/transport/proxy_protocol_header_parser.h"
+
+#include "mongo/unittest/assert_that.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/shared_buffer.h"
+
+namespace mongo::transport {
+namespace {
+
+using namespace unittest::match;
+using namespace fmt::literals;
+
+template <typename MSrc, typename MDst>
+class ProxiedEndpointsAre : public Matcher {
+public:
+ explicit ProxiedEndpointsAre(MSrc&& s, MDst&& d) : _src(std::move(s)), _dst(std::move(d)) {}
+
+ std::string describe() const {
+ return "ProxiedEndpointsAre({}, {})"_format(_src.describe(), _dst.describe());
+ }
+
+ MatchResult match(const ProxiedEndpoints& e) const {
+ return StructuredBindingsAre<MSrc, MDst>(_src, _dst).match(e);
+ }
+
+private:
+ MSrc _src;
+ MDst _dst;
+};
+
+ParserResults parseAllPrefixes(StringData s) {
+ boost::optional<ParserResults> results;
+ for (size_t len = 0; len <= s.size(); ++len) {
+ StringData sub = s.substr(0, len);
+ results = parseProxyProtocolHeader(sub);
+ if (len < s.size()) {
+ ASSERT_FALSE(results) << "size={}, sub={}"_format(len, sub);
+ }
+ }
+ ASSERT_TRUE(results);
+ return *results;
+}
+
+void parseStringExpectFailure(StringData s, std::string regex) {
+ try {
+ parseAllPrefixes(s);
+ FAIL("Expected to throw");
+ } catch (const DBException& ex) {
+ ASSERT_THAT(ex.toStatus(), StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex(regex)));
+ }
+}
+
+boost::optional<ProxiedEndpoints> parseStringExpectSuccess(StringData s) {
+ const ParserResults results = parseAllPrefixes(s);
+ ASSERT_THAT(results.bytesParsed, Eq(s.size()));
+
+ // Also test that adding garbage to the end doesn't increase the bytesParsed amount.
+ const boost::optional<ParserResults> possibleResultsWithGarbage =
+ parseProxyProtocolHeader(s + "garbage");
+ ASSERT_TRUE(possibleResultsWithGarbage);
+ const ParserResults resultsWithGarbage = *possibleResultsWithGarbage;
+ ASSERT_THAT(resultsWithGarbage.bytesParsed, Eq(s.size()));
+ if (results.endpoints) {
+ ASSERT_THAT(*results.endpoints,
+ ProxiedEndpointsAre(Eq(resultsWithGarbage.endpoints->sourceAddress),
+ Eq(resultsWithGarbage.endpoints->destinationAddress)));
+ } else {
+ ASSERT_FALSE(resultsWithGarbage.endpoints);
+ }
+
+ return resultsWithGarbage.endpoints;
+}
+
+TEST(ProxyProtocolHeaderParser, MalformedIpv4Addresses) {
+ StringData testCases[] = {"1",
+ "1.1",
+ "1.1.1",
+ "1.1.1.1.1",
+ "1.1.1.1.",
+ ".1.1.1.1",
+ "1234.1.1.1",
+ "1.1234.1.1",
+ "1.1.1234.1",
+ "1.1.1.1234",
+ "1.1.1.a",
+ "1.1.1.256",
+ "256.1.1.1",
+ "1.1..1.1",
+ "-0.1.1.1",
+ "-1.1.1.1",
+ ""};
+
+ for (const auto& testCase : testCases) {
+ try {
+ proxy_protocol_details::validateIpv4Address(testCase);
+ FAIL("Expected to throw");
+ } catch (const DBException& ex) {
+ ASSERT_THAT(ex.toStatus(),
+ StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("malformed")));
+ }
+ }
+}
+
+TEST(ProxyProtocolHeaderParser, WellFormedIpv4Addresses) {
+ StringData testCases[] = {
+ "1.1.1.1", "0.0.0.0", "255.255.255.255", "0.255.0.255", "127.0.1.1", "1.12.123.0"};
+
+ for (const auto& testCase : testCases) {
+ proxy_protocol_details::validateIpv4Address(testCase);
+ }
+}
+
+TEST(ProxyProtocolHeaderParser, MalformedIpv6Addresses) {
+ StringData testCases[] = {"0000",
+ "0000:0000",
+ "0000:0000:0000",
+ "0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000:",
+ ":0000:0000:0000:0000:0000:0000:0000",
+ "00000:0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000:00000",
+ "0000:-0000:0000:0000:0000:0000:0000:0000",
+ "0000:-000:0000:0000:0000:0000:0000:0000",
+ "000g:0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000:000g",
+ "0000::0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000::0000",
+ "0000:0000:0000:0000:0000:0000:0000:0000::",
+ "::0000:0000:0000:0000:0000:0000:0000:0000",
+ "::0000::",
+ "0000::0000::0000:0000:0000:0000:0000",
+ "0000::0000::0000:0000:0000:0000:0000:0000",
+ "::0000:",
+ ":0000::",
+ ":::",
+ ""};
+
+ for (const auto& testCase : testCases) {
+ try {
+ proxy_protocol_details::validateIpv6Address(testCase);
+ FAIL("Expected to throw");
+ } catch (const DBException& ex) {
+ ASSERT_THAT(ex.toStatus(),
+ StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("malformed")));
+ }
+ }
+}
+
+TEST(ProxyProtocolHeaderParser, WellFormedIpv6Addresses) {
+ StringData testCases[] = {"::",
+ "::0000",
+ "::0000:0000",
+ "::0000:0000:0000",
+ "::0000:0000:0000:0000",
+ "::0000:0000:0000:0000:0000",
+ "::0000:0000:0000:0000:0000:0000",
+ "::0000:0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000::",
+ "0000:0000:0000:0000:0000:0000::",
+ "0000:0000:0000:0000:0000::",
+ "0000:0000:0000:0000::",
+ "0000:0000:0000::",
+ "0000:0000::",
+ "0000::",
+ "0000::0000",
+ "0000::0000:0000",
+ "0000::0000:0000:0000",
+ "0000::0000:0000:0000:0000",
+ "0000::0000:0000:0000:0000:0000",
+ "0000::0000:0000:0000:0000:0000:0000",
+ "0000:0000:0000::0000:0000:0000",
+ "0000:0000:0000::0000:0000",
+ "0000:0000:0000:0000:0000:0000:0000:0000",
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+ "ffff::",
+ "0123:4567:89ab:cdef::",
+ "::0123:4567:89ab:cdef",
+ "0123:4567::89ab:cdef"};
+
+ for (const auto& testCase : testCases) {
+ proxy_protocol_details::validateIpv6Address(testCase);
+ }
+}
+
+TEST(ProxyProtocolHeaderParser, MalformedV1Headers) {
+ std::pair<std::string, std::string> testCases[] = {
+ {"PORXY ", "header bytes invalid"},
+
+ {"PROXY " + std::string(200, '1'), "No terminating newline"},
+
+ // Even if there is a terminating newline, it has to happen before the longest possible
+ // header length is seen.
+ {"PROXY UNKNOWN " + std::string(92, '1') + "\r\n", "No terminating newline"},
+
+ {"PROXY " + std::string(50, '\r') + "1" + "\r\r\n", "address string malformed"},
+ {"PROXY TCP4 \r\n", "address string malformed"},
+ {"PROXY TCP4 1.1.1.1\r\n", "address string malformed"},
+ {"PROXY TCP4 1.1.1.1 1.1.1.1 10\r\n", "address string malformed"},
+
+ {"PROXY TCP4 12800000000 28 10 10\r\n", "malformed IPv4"},
+ {"PROXY TCP4 128 28000000000000 10 10\r\n", "malformed IPv4"},
+ {"PROXY TCP4 1.1.1.1 notanip 10 300\r\n", "malformed IPv4"},
+ {"PROXY TCP4 a:b:c:d 20 10 300\r\n", "malformed IPv4"},
+ {"PROXY TCP4 1.1.1.1 1.1.1.1 -10 300\r\n", "Negative"},
+ {"PROXY TCP4 1.1.1.1 1.1.1.1 10 -300\r\n", "Negative"},
+ {"PROXY TCP4 1.1.1.1 2.2.2.2 notaport 10\r\n", "Did not consume"},
+ {"PROXY TCP4 1.1.1.1 2.2.2.2 10 20garbage\r\n", "Did not consume"},
+
+ // Check TCP6
+ {"PROXY TCP6 \r\n", "address string malformed"},
+ {"PROXY TCP6 ::\r\n", "address string malformed"},
+ {"PROXY TCP6 :: :: 10\r\n", "address string malformed"},
+
+ {"PROXY TCP6 1.1.1.1 2.2.2.2 10 10\r\n", "malformed IPv6"},
+ {"PROXY TCP6 :: ::000g 10 10\r\n", "malformed IPv6"},
+ {"PROXY TCP6 :: :: -10 10\r\n", "Negative"},
+ {"PROXY TCP6 :: :: 10 -10\r\n", "Negative"},
+ {"PROXY TCP6 :: notanip 10 300\r\n", "malformed IPv6"},
+ {"PROXY TCP6 :: :: notaport 10\r\n", "Did not consume"},
+ {"PROXY TCP6 :: :: 10 20garbage\r\n", "Did not consume"}};
+
+ for (const auto& testCase : testCases) {
+ parseStringExpectFailure(testCase.first, testCase.second);
+ }
+}
+
+TEST(ProxyProtocolHeaderParser, WellFormedV1Headers) {
+ ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP4 1.1.1.1 2.2.2.2 10 300\r\n"),
+ ProxiedEndpointsAre(Eq(SockAddr::create("1.1.1.1", 10, AF_INET)),
+ Eq(SockAddr::create("2.2.2.2", 300, AF_INET))));
+
+ ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP4 0.0.0.128 0.0.1.44 1000 3000\r\n"),
+ ProxiedEndpointsAre(Eq(SockAddr::create("0.0.0.128", 1000, AF_INET)),
+ Eq(SockAddr::create("0.0.1.44", 3000, AF_INET))));
+
+ static constexpr StringData allFs = "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"_sd;
+ ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 {} "
+ "{} 10000 30000\r\n"_format(allFs, allFs)),
+ ProxiedEndpointsAre(Eq(SockAddr::create(allFs, 10000, AF_INET6)),
+ Eq(SockAddr::create(allFs, 30000, AF_INET6))));
+
+ ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 :: {} 1000 3000\r\n"_format(allFs)),
+ ProxiedEndpointsAre(Eq(SockAddr::create("::", 1000, AF_INET6)),
+ Eq(SockAddr::create(allFs, 3000, AF_INET6))));
+
+ ASSERT_THAT(*parseStringExpectSuccess("PROXY TCP6 2001:0db8:: 0064:ff9b::0000 1000 3000\r\n"),
+ ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1000, AF_INET6)),
+ Eq(SockAddr::create("64:ff9b::0000", 3000, AF_INET6))));
+
+ // The shortest possible V1 header
+ ASSERT_FALSE(parseStringExpectSuccess("PROXY UNKNOWN\r\n"));
+ ASSERT_FALSE(
+ parseStringExpectSuccess("PROXY UNKNOWN 2001:db8:: 64:ff9b::0.0.0.0 1000 3000\r\n"));
+ ASSERT_FALSE(parseStringExpectSuccess("PROXY UNKNOWN hot garbage\r\n"));
+ // The longest possible V1 header
+ ASSERT_FALSE(
+ parseStringExpectSuccess("PROXY UNKNOWN {} {} 65535 65535\r\n"_format(allFs, allFs)));
+}
+
+struct TestV2Header {
+ std::string header;
+ std::string versionAndCommand;
+ std::string addressFamilyAndProtocol;
+ std::string length;
+ std::string firstAddr;
+ std::string secondAddr;
+ std::string metadata;
+
+ std::string toString() const {
+ return "{}{}{}{}{}{}{}"_format(header,
+ versionAndCommand,
+ addressFamilyAndProtocol,
+ length,
+ firstAddr,
+ secondAddr,
+ metadata);
+ }
+};
+
+TEST(ProxyProtocolHeaderParser, MalformedV2Headers) {
+ // These strings contain null characters in them, so we need string literals.
+ using namespace std::string_literals;
+
+ TestV2Header header;
+
+ // Specify an invalid header.
+ header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x48\x54\x0A"s;
+ parseStringExpectFailure(header.toString(), "header bytes invalid");
+
+ // Correct the header but break the version.
+ header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"s;
+ header.versionAndCommand = "\x30";
+ parseStringExpectFailure(header.toString(), "Invalid version");
+ header.versionAndCommand = "\x22";
+ parseStringExpectFailure(header.toString(), "Invalid version");
+ // Correct the version/command but break the address family.
+ header.versionAndCommand = "\x21";
+ header.addressFamilyAndProtocol = "\x40";
+ parseStringExpectFailure(header.toString(), "Invalid address");
+ header.addressFamilyAndProtocol = "\x23";
+ parseStringExpectFailure(header.toString(), "Invalid protocol");
+
+ // TCP4
+ // Set the length to 1.
+ header.addressFamilyAndProtocol = "\x11";
+ header.length = "\x00\x01"s;
+ header.firstAddr = std::string(1, '\0');
+ parseStringExpectFailure(header.toString(), "too short");
+ // Set to the longest non-valid length (11).
+ header.length = "\x00\x0B"s;
+ header.firstAddr = std::string(6, '\0');
+ header.secondAddr = std::string(5, '\0');
+ parseStringExpectFailure(header.toString(), "too short");
+
+ // TCP6
+ // Set the length to 1.
+ header.addressFamilyAndProtocol = "\x21";
+ header.length = "\x00\x01"s;
+ header.firstAddr = std::string(1, '\0');
+ header.secondAddr = "";
+ parseStringExpectFailure(header.toString(), "too short");
+ // Set to the longest non-valid length (35).
+ header.length = "\x00\x23"s;
+ header.firstAddr = std::string(18, '\0');
+ header.secondAddr = std::string(17, '\0');
+ parseStringExpectFailure(header.toString(), "too short");
+
+ // UNIX
+ // Set the length to 1.
+ header.addressFamilyAndProtocol = "\x31";
+ header.length = "\x00\x01"s;
+ header.firstAddr = std::string(1, '\0');
+ header.secondAddr = "";
+ parseStringExpectFailure(header.toString(), "too short");
+ // Set to the longest non-valid length (35).
+ header.length = "\x00\xD7"s;
+ header.firstAddr = std::string(108, '\0');
+ header.secondAddr = std::string(107, '\0');
+ parseStringExpectFailure(header.toString(), "too short");
+}
+
+template <typename SockAddrUn = sockaddr_un>
+std::pair<SockAddrUn, SockAddrUn> createTestSockAddrUn(std::string srcPath, std::string dstPath) {
+ std::pair<SockAddrUn, SockAddrUn> addrs;
+ addrs.first.sun_family = addrs.second.sun_family = AF_UNIX;
+
+ memcpy(addrs.first.sun_path,
+ srcPath.c_str(),
+ std::min(srcPath.size(), sizeof(SockAddrUn::sun_path)));
+ memcpy(addrs.second.sun_path,
+ dstPath.c_str(),
+ std::min(dstPath.size(), sizeof(SockAddrUn::sun_path)));
+
+ return addrs;
+}
+
+std::string createTestUnixPathString(StringData path) {
+ std::string out{path};
+ out.resize(proxy_protocol_details::kMaxUnixPathLength);
+ return out;
+}
+
+TEST(ProxyProtocolHeaderParser, WellFormedV2Headers) {
+ // These strings contain null characters in them, so we need string literals.
+ using namespace std::string_literals;
+
+ TestV2Header header;
+ // TCP4
+ header.header = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"s;
+ header.versionAndCommand = "\x21"s;
+ header.addressFamilyAndProtocol = "\x12"s;
+ header.length = "\x00\x0C"s;
+ header.firstAddr = "\x0C\x22\x38\x4e\xac\x10"s;
+ header.secondAddr = "\x00\x01\x04\xd2\x1f\x90"s;
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr::create("12.34.56.78", 1234, AF_INET)),
+ Eq(SockAddr::create("172.16.0.1", 8080, AF_INET))));
+
+ // We correctly ignore data at the end.
+ header.length = "\x00\x0E"s;
+ header.metadata = std::string(2, '\0');
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr::create("12.34.56.78", 1234, AF_INET)),
+ Eq(SockAddr::create("172.16.0.1", 8080, AF_INET))));
+
+ // If this is a local connection, return nothing.
+ header.versionAndCommand = "\x20"s;
+ ASSERT_FALSE(parseStringExpectSuccess(header.toString()));
+
+ // TCP6
+ header.versionAndCommand = "\x21"s;
+ header.addressFamilyAndProtocol = "\x21"s;
+ header.length = "\x00\x24"s;
+ header.firstAddr = "\x20\x1\xd\xb8\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x64"s;
+ header.secondAddr = "\xff\x9b\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x0\x4\xd2\x1f\x90"s;
+ header.metadata = "";
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1234, AF_INET6)),
+ Eq(SockAddr::create("64:ff9b::0.0.0.0", 8080, AF_INET6))));
+
+ // We correctly ignore data at the end.
+ header.length = "\x00\x55"s;
+ header.metadata = std::string(49, '\1');
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr::create("2001:db8::", 1234, AF_INET6)),
+ Eq(SockAddr::create("64:ff9b::0.0.0.0", 8080, AF_INET6))));
+
+ // If this is a local connection, return nothing.
+ header.versionAndCommand = "\x20"s;
+ ASSERT_FALSE(parseStringExpectSuccess(header.toString()));
+
+ // UNIX
+ header.versionAndCommand = "\x21"s;
+ header.addressFamilyAndProtocol = "\x31"s;
+ header.length = "\x00\xD8"s;
+ const std::string srcPath(sizeof(sockaddr_un::sun_path) / 2, '\1');
+ const std::string dstPath(sizeof(sockaddr_un::sun_path) - 1, '\2');
+ header.firstAddr = createTestUnixPathString(srcPath);
+ header.secondAddr = createTestUnixPathString(dstPath);
+ header.metadata = "";
+ const auto addrs = createTestSockAddrUn(srcPath, dstPath);
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr((sockaddr*)&addrs.first, sizeof(sockaddr_un))),
+ Eq(SockAddr((sockaddr*)&addrs.second, sizeof(sockaddr_un)))));
+
+ // We correctly ignore data at the end.
+ header.length = "\x01\x08";
+ header.metadata = std::string(48, '\0');
+ // Extraneous data at the end is correctly ingested and ignored.
+ ASSERT_THAT(*parseStringExpectSuccess(header.toString()),
+ ProxiedEndpointsAre(Eq(SockAddr((sockaddr*)&addrs.first, sizeof(sockaddr_un))),
+ Eq(SockAddr((sockaddr*)&addrs.second, sizeof(sockaddr_un)))));
+
+ // If this is a local connection, return nothing.
+ header.versionAndCommand = "\x20"s;
+ ASSERT_FALSE(parseStringExpectSuccess(header.toString()));
+
+ // The family is not parsed if the connection is local.
+ header.addressFamilyAndProtocol = "\xA2";
+ ASSERT_FALSE(parseStringExpectSuccess(header.toString()));
+}
+
+struct TestSockAddrUnLinux {
+ sa_family_t sun_family;
+ char sun_path[108];
+
+ friend bool operator==(const TestSockAddrUnLinux& a, const TestSockAddrUnLinux& b) {
+ return a.sun_family == b.sun_family && !memcmp(a.sun_path, b.sun_path, sizeof(sun_path));
+ }
+};
+
+TEST(ProxyProtocolHeaderParser, LinuxSockAddrUnParsing) {
+ // Test the parser against a Linux-like sockaddr_un
+ {
+ const std::string srcPath(sizeof(TestSockAddrUnLinux::sun_path) / 2, '\1');
+ const std::string dstPath(sizeof(TestSockAddrUnLinux::sun_path) - 1, '\2');
+ const auto addrs = createTestSockAddrUn<TestSockAddrUnLinux>(srcPath, dstPath);
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>(
+ createTestUnixPathString(srcPath)),
+ Eq(addrs.first));
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>(
+ createTestUnixPathString(dstPath)),
+ Eq(addrs.second));
+ }
+
+ {
+ const std::string srcPath(sizeof(TestSockAddrUnLinux::sun_path), '\0');
+ const std::string dstPath(1, '\0');
+ const auto addrs = createTestSockAddrUn<TestSockAddrUnLinux>(srcPath, dstPath);
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>(
+ createTestUnixPathString(srcPath)),
+ Eq(addrs.first));
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnLinux>(
+ createTestUnixPathString(dstPath)),
+ Eq(addrs.second));
+ }
+}
+
+struct TestSockAddrUnMac {
+ unsigned char sun_len;
+ sa_family_t sun_family;
+ char sun_path[104];
+
+ friend bool operator==(const TestSockAddrUnMac& a, const TestSockAddrUnMac& b) {
+ return a.sun_family == b.sun_family && a.sun_len == b.sun_len &&
+ !memcmp(a.sun_path, b.sun_path, a.sun_len);
+ }
+};
+
+TEST(ProxyProtocolHeaderParser, MacSockAddrUnParsing) {
+ // Test the parser against a Mac-like sockaddr_un
+ {
+ const std::string srcPath(sizeof(TestSockAddrUnMac::sun_path) / 2, '\1');
+ const std::string dstPath(sizeof(TestSockAddrUnMac::sun_path) - 1, '\2');
+ const auto addrs = createTestSockAddrUn<TestSockAddrUnMac>(srcPath, dstPath);
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(srcPath)),
+ Eq(addrs.first));
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(dstPath)),
+ Eq(addrs.second));
+ }
+
+ {
+ const std::string srcPath(1, '\1');
+ const std::string dstPath = "";
+ const auto addrs = createTestSockAddrUn<TestSockAddrUnMac>(srcPath, dstPath);
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(srcPath)),
+ Eq(addrs.first));
+ ASSERT_THAT(proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(dstPath)),
+ Eq(addrs.second));
+ }
+
+ try {
+ const std::string srcPath(proxy_protocol_details::kMaxUnixPathLength - 1, '\1');
+ const std::string dstPath(proxy_protocol_details::kMaxUnixPathLength - 1, '\2');
+ proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(srcPath) + createTestUnixPathString(dstPath));
+ FAIL("Expected to throw");
+ } catch (const DBException& ex) {
+ ASSERT_THAT(ex.toStatus(),
+ StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("longer than system")));
+ }
+
+ try {
+ const std::string srcPath(sizeof(TestSockAddrUnMac::sun_path), '\1');
+ const std::string dstPath(sizeof(TestSockAddrUnMac::sun_path), '\2');
+ proxy_protocol_details::parseSockAddrUn<TestSockAddrUnMac>(
+ createTestUnixPathString(srcPath) + createTestUnixPathString(dstPath));
+ FAIL("Expected to throw");
+ } catch (const DBException& ex) {
+ ASSERT_THAT(ex.toStatus(),
+ StatusIs(Eq(ErrorCodes::FailedToParse), ContainsRegex("longer than system")));
+ }
+}
+
+} // namespace
+} // namespace mongo::transport