summaryrefslogtreecommitdiff
path: root/fuzz/tls_client_target.cc
blob: d1dda12d4d4ae04243fc68e5bb9791dc2215c77a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include <assert.h>
#include <stdint.h>
#include <memory>

#include "blapi.h"
#include "prinit.h"
#include "ssl.h"

#include "shared.h"
#include "tls_client_config.h"
#include "tls_client_socket.h"

// Sets the process-wide default version range for stream sockets to the full
// range this NSS build reports as supported. Any NSS failure trips an assert;
// the function itself unconditionally returns PR_SUCCESS.
static PRStatus EnableAllProtocolVersions() {
  SSLVersionRange range;

  const SECStatus queried =
      SSL_VersionRangeGetSupported(ssl_variant_stream, &range);
  assert(queried == SECSuccess);

  const SECStatus applied =
      SSL_VersionRangeSetDefault(ssl_variant_stream, &range);
  assert(applied == SECSuccess);

  return PR_SUCCESS;
}

// Certificate-authentication hook installed on the client socket.
//
// `arg` is the ClientConfig* registered in SetupCallbacks. The verdict is
// taken entirely from the fuzzer-derived config: SECFailure when the config
// asks for authentication to fail, SECSuccess otherwise. `fd` and `checksig`
// are unused. `isServer` must be false because this harness only runs the
// client side.
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, PRBool checksig,
                                     PRBool isServer) {
  assert(!isServer);
  // static_cast is the well-defined conversion back from the void* we
  // registered; reinterpret_cast is unnecessary here.
  auto config = static_cast<ClientConfig*>(arg);
  return config->FailCertificateAuthentication() ? SECFailure : SECSuccess;
}

static void SetSocketOptions(PRFileDesc* fd,
                             std::unique_ptr<ClientConfig>& config) {
  // Disable session cache for now.
  SECStatus rv = SSL_OptionSet(fd, SSL_NO_CACHE, true);
  assert(rv == SECSuccess);

  rv = SSL_OptionSet(fd, SSL_ENABLE_EXTENDED_MASTER_SECRET,
                     config->EnableExtendedMasterSecret());
  assert(rv == SECSuccess);

  rv = SSL_OptionSet(fd, SSL_REQUIRE_DH_NAMED_GROUPS,
                     config->RequireDhNamedGroups());
  assert(rv == SECSuccess);

  rv = SSL_OptionSet(fd, SSL_ENABLE_FALSE_START, config->EnableFalseStart());
  assert(rv == SECSuccess);

  rv =
      SSL_OptionSet(fd, SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED);
  assert(rv == SECSuccess);
}

// Turns on every cipher suite listed in SSL_ImplementedCiphers for `fd`.
// Each SSL_CipherPrefSet call is expected to succeed (asserted).
static void EnableAllCipherSuites(PRFileDesc* fd) {
  const uint16_t count = SSL_NumImplementedCiphers;
  for (uint16_t idx = 0; idx != count; ++idx) {
    const uint16_t suite = SSL_ImplementedCiphers[idx];
    const SECStatus status = SSL_CipherPrefSet(fd, suite, true);
    assert(status == SECSuccess);
  }
}

// False-start approval callback. It is registered unconditionally in
// SetupCallbacks; the assumption is that NSS consults it only when
// SSL_ENABLE_FALSE_START is on. It therefore always grants false start.
// `fd` and `arg` are unused.
static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
                                       PRBool* canFalseStart) {
  PRBool& allow = *canFalseStart;
  allow = true;
  return SECSuccess;
}

// Installs the harness callbacks on `fd`:
//  - AuthCertificateHook, which receives `config` as its argument,
//  - CanFalseStartCallback, which needs no argument.
// Both registrations are expected to succeed (asserted).
static void SetupCallbacks(PRFileDesc* fd, ClientConfig* config) {
  const SECStatus authStatus =
      SSL_AuthCertificateHook(fd, AuthCertificateHook, config);
  assert(authStatus == SECSuccess);

  const SECStatus falseStartStatus =
      SSL_SetCanFalseStartCallback(fd, CanFalseStartCallback, nullptr);
  assert(falseStartStatus == SECSuccess);
}

// Resets `fd` into client mode and drives the handshake. It retries
// SSL_ForceHandshake for as long as it fails with PR_WOULD_BLOCK_ERROR.
// If the handshake succeeds, it echoes back whatever application data
// PR_Read yields, in 1024-byte chunks, until PR_Read returns <= 0. The
// result of each PR_Write is ignored.
// NOTE(review): the retry loop only ends if the underlying socket eventually
// stops reporting PR_WOULD_BLOCK_ERROR.
static void DoHandshake(PRFileDesc* fd) {
  SECStatus rv = SSL_ResetHandshake(fd, false /* asServer */);
  assert(rv == SECSuccess);

  for (;;) {
    rv = SSL_ForceHandshake(fd);
    if (rv == SECSuccess || PR_GetError() != PR_WOULD_BLOCK_ERROR) {
      break;
    }
  }

  // Nothing to read if the handshake did not complete.
  if (rv != SECSuccess) {
    return;
  }

  uint8_t buffer[1024];
  for (int32_t got = PR_Read(fd, buffer, sizeof(buffer)); got > 0;
       got = PR_Read(fd, buffer, sizeof(buffer))) {
    PR_Write(fd, buffer, got);
  }
}

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t len) {
  static std::unique_ptr<NSSDatabase> db(new NSSDatabase());
  assert(db != nullptr);

  EnableAllProtocolVersions();
  std::unique_ptr<ClientConfig> config(new ClientConfig(data, len));

#ifdef UNSAFE_FUZZER_MODE
  // Reset the RNG state.
  SECStatus rv = RNG_ResetForFuzzing();
  assert(rv == SECSuccess);
#endif

  // Create and import dummy socket.
  std::unique_ptr<DummyPrSocket> socket(new DummyPrSocket(data, len));
  static PRDescIdentity id = PR_GetUniqueIdentity("fuzz-client");
  ScopedPRFileDesc fd(DummyIOLayerMethods::CreateFD(id, socket.get()));
  PRFileDesc* ssl_fd = SSL_ImportFD(nullptr, fd.get());
  assert(ssl_fd == fd.get());

  // Probably not too important for clients.
  SSL_SetURL(ssl_fd, "server");

  SetSocketOptions(ssl_fd, config);
  EnableAllCipherSuites(ssl_fd);
  SetupCallbacks(ssl_fd, config.get());
  DoHandshake(ssl_fd);

  return 0;
}