summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSara Golemon <sara.golemon@mongodb.com>2017-06-27 10:00:42 -0400
committerSara Golemon <sara.golemon@mongodb.com>2017-07-06 16:36:35 -0400
commit4b222edf455a34667cfaf7b67e7f8dfdca42bd9c (patch)
tree5f468e532fa02650803ea20036c7b9a9161ce96f /src
parentf0b95cc0c48fe242edbc9c7958f9df0a34813e78 (diff)
downloadmongo-4b222edf455a34667cfaf7b67e7f8dfdca42bd9c.tar.gz
SERVER-15194 Refactor base64::decode Implementation
* Existing check for length as multiple of 4 as-is * Added check for non-base64 characters on input * Added check for terminators ('=') midstream Implicitly in positions 0 and 1 via non-base64 check Explicitly in positions 2 and 3 via "done" check. Moved "Alphabet" class into cpp file in anon namespace as it's an implementation detail and shouldn't be used by outside classes. Added base64::validate() method to accomodate BSON's isBase64String() check.
Diffstat (limited to 'src')
-rw-r--r--src/mongo/bson/json.cpp8
-rw-r--r--src/mongo/util/SConscript10
-rw-r--r--src/mongo/util/base64.cpp105
-rw-r--r--src/mongo/util/base64.h48
-rw-r--r--src/mongo/util/base64_test.cpp93
5 files changed, 183 insertions, 81 deletions
diff --git a/src/mongo/bson/json.cpp b/src/mongo/bson/json.cpp
index d6a94ea08eb..647accf31ba 100644
--- a/src/mongo/bson/json.cpp
+++ b/src/mongo/bson/json.cpp
@@ -1280,13 +1280,7 @@ bool JParse::isHexString(StringData str) const {
bool JParse::isBase64String(StringData str) const {
MONGO_JSON_DEBUG("str: " << str);
- std::size_t i;
- for (i = 0; i < str.size(); i++) {
- if (!match(str[i], base64::chars)) {
- return false;
- }
- }
- return true;
+ return base64::validate(str);
}
bool JParse::isArray() {
diff --git a/src/mongo/util/SConscript b/src/mongo/util/SConscript
index e348375d4a7..856d15ae6e1 100644
--- a/src/mongo/util/SConscript
+++ b/src/mongo/util/SConscript
@@ -610,6 +610,16 @@ env.CppUnitTest(
]
)
+env.CppUnitTest(
+ target='base64_test',
+ source=[
+ 'base64_test.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/base',
+ ],
+)
+
if env.TargetOSIs('linux'):
env.Library(
target='procparser',
diff --git a/src/mongo/util/base64.cpp b/src/mongo/util/base64.cpp
index 59fdfb63872..2f1f28bfda2 100644
--- a/src/mongo/util/base64.cpp
+++ b/src/mongo/util/base64.cpp
@@ -32,18 +32,53 @@
#include "mongo/util/base64.h"
-#include <sstream>
+#include "mongo/util/assert_util.h"
+
+#include <array>
namespace mongo {
+using std::begin;
+using std::end;
using std::string;
using std::stringstream;
-namespace base64 {
+namespace {
+constexpr unsigned char kInvalid = -1;
+
+const class Alphabet {
+public:
+ Alphabet() {
+ decode.fill(kInvalid);
+ for (size_t i = 0; i < encode.size(); ++i) {
+ decode[encode[i]] = i;
+ }
+ }
+
+ unsigned char e(std::uint8_t x) const {
+ return encode[x & 0x3f];
+ }
+
+ std::uint8_t d(unsigned char x) const {
+ auto const c = decode[x];
+ uassert(40533, "Invalid base64 character", c != kInvalid);
+ return c;
+ }
+
+ bool valid(unsigned char x) const {
+ return decode[x] != kInvalid;
+ }
-Alphabet alphabet;
+private:
+ StringData encode{
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/"};
+ std::array<unsigned char, 256> decode;
+} alphabet;
+} // namespace
-void encode(stringstream& ss, const char* data, int size) {
+void base64::encode(stringstream& ss, const char* data, int size) {
for (int i = 0; i < size; i += 3) {
int left = size - i;
const unsigned char* start = (const unsigned char*)data + i;
@@ -82,51 +117,59 @@ void encode(stringstream& ss, const char* data, int size) {
}
-string encode(const char* data, int size) {
+string base64::encode(const char* data, int size) {
stringstream ss;
encode(ss, data, size);
return ss.str();
}
-string encode(const string& s) {
+string base64::encode(const string& s) {
return encode(s.c_str(), s.size());
}
-void decode(stringstream& ss, const string& s) {
+void base64::decode(stringstream& ss, const string& s) {
uassert(10270, "invalid base64", s.size() % 4 == 0);
- const unsigned char* data = (const unsigned char*)s.c_str();
- int size = s.size();
-
- unsigned char buf[3];
- for (int i = 0; i < size; i += 4) {
- const unsigned char* start = data + i;
- buf[0] =
- ((alphabet.decode[start[0]] << 2) & 0xFC) | ((alphabet.decode[start[1]] >> 4) & 0x3);
- buf[1] =
- ((alphabet.decode[start[1]] << 4) & 0xF0) | ((alphabet.decode[start[2]] >> 2) & 0xF);
- buf[2] = ((alphabet.decode[start[2]] << 6) & 0xC0) | ((alphabet.decode[start[3]] & 0x3F));
-
- int len = 3;
- if (start[3] == '=') {
- len = 2;
- if (start[2] == '=') {
- len = 1;
+ auto const data = reinterpret_cast<const unsigned char*>(s.c_str());
+ auto const size = s.size();
+ bool done = false;
+
+ for (size_t i = 0; i < size; i += 4) {
+ uassert(
+ 40534, "Invalid Base64 stream. Additional data following terminating sequence.", !done);
+ auto const start = data + i;
+ done = (start[2] == '=') || (start[3] == '=');
+
+ ss << (char)(((alphabet.d(start[0]) << 2) & 0xFC) | ((alphabet.d(start[1]) >> 4) & 0x3));
+ if (start[2] != '=') {
+ ss << (char)(((alphabet.d(start[1]) << 4) & 0xF0) |
+ ((alphabet.d(start[2]) >> 2) & 0xF));
+ if (!done) {
+ ss << (char)(((alphabet.d(start[2]) << 6) & 0xC0) |
+ ((alphabet.d(start[3]) & 0x3F)));
}
}
- ss.write((const char*)buf, len);
}
}
-string decode(const string& s) {
+string base64::decode(const string& s) {
stringstream ss;
decode(ss, s);
return ss.str();
}
-const char* chars =
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789+/=";
-}
+bool base64::validate(const StringData s) {
+ if (s.size() % 4) {
+ return false;
+ }
+ if (s.empty()) {
+ return true;
+ }
+
+ auto const unwindTerminator = [](auto it) { return (*(it - 1) == '=') ? (it - 1) : it; };
+ auto const e = unwindTerminator(unwindTerminator(end(s)));
+
+ return e == std::find_if(begin(s), e, [](const char ch) { return !alphabet.valid(ch); });
}
+
+} // namespace mongo
diff --git a/src/mongo/util/base64.h b/src/mongo/util/base64.h
index 740811936ff..a9f3ea93b57 100644
--- a/src/mongo/util/base64.h
+++ b/src/mongo/util/base64.h
@@ -29,51 +29,14 @@
#pragma once
-#include <iosfwd>
-#include <memory>
+#include <sstream>
#include <string>
-#include "mongo/util/assert_util.h"
+#include "mongo/base/string_data.h"
namespace mongo {
namespace base64 {
-class Alphabet {
-public:
- Alphabet()
- : encode((unsigned char*)
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789"
- "+/")
- , decode(new unsigned char[257]) {
- memset(decode.get(), 0, 256);
- for (int i = 0; i < 64; i++) {
- decode[encode[i]] = i;
- }
-
- test();
- }
- void test() {
- verify(strlen((char*)encode) == 64);
- for (int i = 0; i < 26; i++)
- verify(encode[i] == toupper(encode[i + 26]));
- }
-
- char e(int x) {
- return encode[x & 0x3f];
- }
-
-private:
- const unsigned char* encode;
-
-public:
- std::unique_ptr<unsigned char[]> decode;
-};
-
-extern Alphabet alphabet;
-
-
void encode(std::stringstream& ss, const char* data, int size);
std::string encode(const char* data, int size);
std::string encode(const std::string& s);
@@ -81,8 +44,7 @@ std::string encode(const std::string& s);
void decode(std::stringstream& ss, const std::string& s);
std::string decode(const std::string& s);
-extern const char* chars;
+bool validate(StringData);
-void testAlphabet();
-}
-}
+} // namespace base64
+} // namespace mongo
diff --git a/src/mongo/util/base64_test.cpp b/src/mongo/util/base64_test.cpp
new file mode 100644
index 00000000000..ab347da87cc
--- /dev/null
+++ b/src/mongo/util/base64_test.cpp
@@ -0,0 +1,93 @@
+/**
+ * Copyright (C) 2017 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/base64.h"
+
+namespace mongo {
+namespace {
+
+TEST(Base64Test, transcode) {
+ const struct {
+ std::string plain;
+ std::string encoded;
+ } tests[] = {
+ {"", ""},
+ {"a", "YQ=="},
+ {"aa", "YWE="},
+ {"aaa", "YWFh"},
+ {"aaaa", "YWFhYQ=="},
+
+ {"A", "QQ=="},
+ {"AA", "QUE="},
+ {"AAA", "QUFB"},
+ {"AAAA", "QUFBQQ=="},
+
+ {"The quick brown fox jumped over the lazy dog.",
+ "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"},
+ {std::string("\0\1\2\3\4\5\6\7", 8), "AAECAwQFBgc="},
+ {std::string("\0\277\1\276\2\275", 6), "AL8BvgK9"},
+ };
+
+ for (auto const& t : tests) {
+ ASSERT_TRUE(base64::validate(t.encoded));
+
+ ASSERT_EQUALS(base64::encode(t.plain), t.encoded);
+ ASSERT_EQUALS(base64::decode(t.encoded), t.plain);
+ }
+}
+
+TEST(Base64Test, parseFail) {
+ const struct {
+ std::string encoded;
+ int code;
+ } tests[] = {
+ {"BadLength", 10270},
+ {"Has Whitespace==", 40533},
+ {"Hasbadchar$=", 40533},
+ {"Hasbadchar\xFF=", 40533},
+ {"Hasbadcahr\t=", 40533},
+ {"too=soon", 40534},
+ };
+
+ for (auto const& t : tests) {
+ ASSERT_FALSE(base64::validate(t.encoded));
+
+ try {
+ base64::decode(t.encoded);
+ ASSERT_TRUE(false);
+ } catch (const UserException& e) {
+ ASSERT_EQ(e.getCode(), t.code);
+ }
+ }
+}
+
+} // namespace
+} // namespace mongo