diff options
author | Sara Golemon <sara.golemon@mongodb.com> | 2017-06-27 10:00:42 -0400 |
---|---|---|
committer | Sara Golemon <sara.golemon@mongodb.com> | 2017-07-06 16:36:35 -0400 |
commit | 4b222edf455a34667cfaf7b67e7f8dfdca42bd9c (patch) | |
tree | 5f468e532fa02650803ea20036c7b9a9161ce96f | |
parent | f0b95cc0c48fe242edbc9c7958f9df0a34813e78 (diff) | |
download | mongo-4b222edf455a34667cfaf7b67e7f8dfdca42bd9c.tar.gz |
SERVER-15194 Refactor base64::decode Implementation
* Existing check for length as multiple of 4 as-is
* Added check for non-base64 characters on input
* Added check for terminators ('=') midstream
Implicitly in positions 0 and 1 via non-base64 check
Explicitly in positions 2 and 3 via "done" check.
Moved "Alphabet" class into cpp file in anon namespace
as it's an implementation detail and shouldn't be used
by outside classes.
Added base64::validate() method to accomodate BSON's
isBase64String() check.
-rw-r--r-- | src/mongo/bson/json.cpp | 8 | ||||
-rw-r--r-- | src/mongo/util/SConscript | 10 | ||||
-rw-r--r-- | src/mongo/util/base64.cpp | 105 | ||||
-rw-r--r-- | src/mongo/util/base64.h | 48 | ||||
-rw-r--r-- | src/mongo/util/base64_test.cpp | 93 |
5 files changed, 183 insertions, 81 deletions
diff --git a/src/mongo/bson/json.cpp b/src/mongo/bson/json.cpp index d6a94ea08eb..647accf31ba 100644 --- a/src/mongo/bson/json.cpp +++ b/src/mongo/bson/json.cpp @@ -1280,13 +1280,7 @@ bool JParse::isHexString(StringData str) const { bool JParse::isBase64String(StringData str) const { MONGO_JSON_DEBUG("str: " << str); - std::size_t i; - for (i = 0; i < str.size(); i++) { - if (!match(str[i], base64::chars)) { - return false; - } - } - return true; + return base64::validate(str); } bool JParse::isArray() { diff --git a/src/mongo/util/SConscript b/src/mongo/util/SConscript index e348375d4a7..856d15ae6e1 100644 --- a/src/mongo/util/SConscript +++ b/src/mongo/util/SConscript @@ -610,6 +610,16 @@ env.CppUnitTest( ] ) +env.CppUnitTest( + target='base64_test', + source=[ + 'base64_test.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/base', + ], +) + if env.TargetOSIs('linux'): env.Library( target='procparser', diff --git a/src/mongo/util/base64.cpp b/src/mongo/util/base64.cpp index 59fdfb63872..2f1f28bfda2 100644 --- a/src/mongo/util/base64.cpp +++ b/src/mongo/util/base64.cpp @@ -32,18 +32,53 @@ #include "mongo/util/base64.h" -#include <sstream> +#include "mongo/util/assert_util.h" + +#include <array> namespace mongo { +using std::begin; +using std::end; using std::string; using std::stringstream; -namespace base64 { +namespace { +constexpr unsigned char kInvalid = -1; + +const class Alphabet { +public: + Alphabet() { + decode.fill(kInvalid); + for (size_t i = 0; i < encode.size(); ++i) { + decode[encode[i]] = i; + } + } + + unsigned char e(std::uint8_t x) const { + return encode[x & 0x3f]; + } + + std::uint8_t d(unsigned char x) const { + auto const c = decode[x]; + uassert(40533, "Invalid base64 character", c != kInvalid); + return c; + } + + bool valid(unsigned char x) const { + return decode[x] != kInvalid; + } -Alphabet alphabet; +private: + StringData encode{ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"}; + std::array<unsigned char, 256> decode; +} alphabet; +} // namespace -void encode(stringstream& ss, const char* data, int size) { +void base64::encode(stringstream& ss, const char* data, int size) { for (int i = 0; i < size; i += 3) { int left = size - i; const unsigned char* start = (const unsigned char*)data + i; @@ -82,51 +117,59 @@ void encode(stringstream& ss, const char* data, int size) { } -string encode(const char* data, int size) { +string base64::encode(const char* data, int size) { stringstream ss; encode(ss, data, size); return ss.str(); } -string encode(const string& s) { +string base64::encode(const string& s) { return encode(s.c_str(), s.size()); } -void decode(stringstream& ss, const string& s) { +void base64::decode(stringstream& ss, const string& s) { uassert(10270, "invalid base64", s.size() % 4 == 0); - const unsigned char* data = (const unsigned char*)s.c_str(); - int size = s.size(); - - unsigned char buf[3]; - for (int i = 0; i < size; i += 4) { - const unsigned char* start = data + i; - buf[0] = - ((alphabet.decode[start[0]] << 2) & 0xFC) | ((alphabet.decode[start[1]] >> 4) & 0x3); - buf[1] = - ((alphabet.decode[start[1]] << 4) & 0xF0) | ((alphabet.decode[start[2]] >> 2) & 0xF); - buf[2] = ((alphabet.decode[start[2]] << 6) & 0xC0) | ((alphabet.decode[start[3]] & 0x3F)); - - int len = 3; - if (start[3] == '=') { - len = 2; - if (start[2] == '=') { - len = 1; + auto const data = reinterpret_cast<const unsigned char*>(s.c_str()); + auto const size = s.size(); + bool done = false; + + for (size_t i = 0; i < size; i += 4) { + uassert( + 40534, "Invalid Base64 stream. Additional data following terminating sequence.", !done); + auto const start = data + i; + done = (start[2] == '=') || (start[3] == '='); + + ss << (char)(((alphabet.d(start[0]) << 2) & 0xFC) | ((alphabet.d(start[1]) >> 4) & 0x3)); + if (start[2] != '=') { + ss << (char)(((alphabet.d(start[1]) << 4) & 0xF0) | + ((alphabet.d(start[2]) >> 2) & 0xF)); + if (!done) { + ss << (char)(((alphabet.d(start[2]) << 6) & 0xC0) | + ((alphabet.d(start[3]) & 0x3F))); } } - ss.write((const char*)buf, len); } } -string decode(const string& s) { +string base64::decode(const string& s) { stringstream ss; decode(ss, s); return ss.str(); } -const char* chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/="; -} +bool base64::validate(const StringData s) { + if (s.size() % 4) { + return false; + } + if (s.empty()) { + return true; + } + + auto const unwindTerminator = [](auto it) { return (*(it - 1) == '=') ? (it - 1) : it; }; + auto const e = unwindTerminator(unwindTerminator(end(s))); + + return e == std::find_if(begin(s), e, [](const char ch) { return !alphabet.valid(ch); }); } + +} // namespace mongo diff --git a/src/mongo/util/base64.h b/src/mongo/util/base64.h index 740811936ff..a9f3ea93b57 100644 --- a/src/mongo/util/base64.h +++ b/src/mongo/util/base64.h @@ -29,51 +29,14 @@ #pragma once -#include <iosfwd> -#include <memory> +#include <sstream> #include <string> -#include "mongo/util/assert_util.h" +#include "mongo/base/string_data.h" namespace mongo { namespace base64 { -class Alphabet { -public: - Alphabet() - : encode((unsigned char*) - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789" - "+/") - , decode(new unsigned char[257]) { - memset(decode.get(), 0, 256); - for (int i = 0; i < 64; i++) { - decode[encode[i]] = i; - } - - test(); - } - void test() { - verify(strlen((char*)encode) == 64); - for (int i = 0; i < 26; i++) - verify(encode[i] == toupper(encode[i + 26])); - } - - char e(int x) { - return encode[x & 0x3f]; - } - -private: - const unsigned char* encode; - -public: - std::unique_ptr<unsigned char[]> decode; -}; - -extern Alphabet alphabet; - - void encode(std::stringstream& ss, const char* data, int size); std::string encode(const char* data, int size); std::string encode(const std::string& s); @@ -81,8 +44,7 @@ std::string encode(const std::string& s); void decode(std::stringstream& ss, const std::string& s); std::string decode(const std::string& s); -extern const char* chars; +bool validate(StringData); -void testAlphabet(); -} -} +} // namespace base64 +} // namespace mongo diff --git a/src/mongo/util/base64_test.cpp b/src/mongo/util/base64_test.cpp new file mode 100644 index 00000000000..ab347da87cc --- /dev/null +++ b/src/mongo/util/base64_test.cpp @@ -0,0 +1,93 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/unittest/unittest.h" +#include "mongo/util/base64.h" + +namespace mongo { +namespace { + +TEST(Base64Test, transcode) { + const struct { + std::string plain; + std::string encoded; + } tests[] = { + {"", ""}, + {"a", "YQ=="}, + {"aa", "YWE="}, + {"aaa", "YWFh"}, + {"aaaa", "YWFhYQ=="}, + + {"A", "QQ=="}, + {"AA", "QUE="}, + {"AAA", "QUFB"}, + {"AAAA", "QUFBQQ=="}, + + {"The quick brown fox jumped over the lazy dog.", + "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"}, + {std::string("\0\1\2\3\4\5\6\7", 8), "AAECAwQFBgc="}, + {std::string("\0\277\1\276\2\275", 6), "AL8BvgK9"}, + }; + + for (auto const& t : tests) { + ASSERT_TRUE(base64::validate(t.encoded)); + + ASSERT_EQUALS(base64::encode(t.plain), t.encoded); + ASSERT_EQUALS(base64::decode(t.encoded), t.plain); + } +} + +TEST(Base64Test, parseFail) { + const struct { + std::string encoded; + int code; + } tests[] = { + {"BadLength", 10270}, + {"Has Whitespace==", 40533}, + {"Hasbadchar$=", 40533}, + {"Hasbadchar\xFF=", 40533}, + {"Hasbadcahr\t=", 40533}, + {"too=soon", 40534}, + }; + + for (auto const& t : tests) { + ASSERT_FALSE(base64::validate(t.encoded)); + + try { + base64::decode(t.encoded); + ASSERT_TRUE(false); + } catch (const UserException& e) { + ASSERT_EQ(e.getCode(), t.code); + } + } +} + +} // namespace +} // namespace mongo |