summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Thomson <martin.thomson@gmail.com>2017-11-28 15:56:47 +1100
committerMartin Thomson <martin.thomson@gmail.com>2017-11-28 15:56:47 +1100
commit1672e50ebff00b4fe09ce1c1a45a252a918de58e (patch)
tree1ef5f9f504d99889c258cebffcbf3304627ba464
parente9f462476f3f4140e214db0bbb29c4822fb334d2 (diff)
downloadnss-hg-1672e50ebff00b4fe09ce1c1a45a252a918de58e.tar.gz
Bug 1423018 - Retain write polling when 0-RTT is enabled, r=ekr
This is a nasty one. I was having trouble with tstclnt when testing against other implementations. It would hang. I made similar changes to that for kazuho and picotls, where the input file is read for every connection. So, I did that here too and it also worked nicely. In order to get 0-RTT working though, I needed to teach ssl_Poll about 0-RTT. Luckily, there was code already there for that and it just needed a tweak. The only thing I ran into here was that boringssl (the server I was using to test against here), was too fast. By the time we had written out the ClientHello, it had produced a response and we would complete the handshake before leaving the handshake loop in ssl3_Do1stHandshake(). That meant that we would never actually send any 0-RTT data, either before or after this patch. Adding a sleep(1) to the handshake in boringssl did the trick and I could show that we can send data before the handshake completes. Nothing really actionable here unless you can think of ways to make our handshake more performant. Mostly just information. Separately, people have noticed that tstclnt writes 0-RTT after the first round trip. That's a symptom of not retaining the poll on write.
-rw-r--r--cmd/tstclnt/tstclnt.c168
-rw-r--r--lib/ssl/sslsecur.c17
-rw-r--r--lib/ssl/sslsock.c16
3 files changed, 96 insertions, 105 deletions
diff --git a/cmd/tstclnt/tstclnt.c b/cmd/tstclnt/tstclnt.c
index 38cbe94b4..f6956596b 100644
--- a/cmd/tstclnt/tstclnt.c
+++ b/cmd/tstclnt/tstclnt.c
@@ -51,6 +51,7 @@
#define MAX_WAIT_FOR_SERVER 600
#define WAIT_INTERVAL 100
+#define ZERO_RTT_MAX (2 << 16)
#define EXIT_CODE_HANDSHAKE_FAILED 254
@@ -99,6 +100,7 @@ int renegotiationsDone = 0;
PRBool initializedServerSessionCache = PR_FALSE;
static char *progName;
+static const char *requestFile;
secuPWData pwdata = { PW_NONE, 0 };
@@ -711,12 +713,18 @@ void
thread_main(void *arg)
{
PRFileDesc *ps = (PRFileDesc *)arg;
- PRFileDesc *std_in = PR_GetSpecialFD(PR_StandardInput);
+ PRFileDesc *std_in;
int wc, rc;
char buf[256];
+ if (requestFile) {
+ std_in = PR_Open(requestFile, PR_RDONLY, 0);
+ } else {
+ std_in = PR_GetSpecialFD(PR_StandardInput);
+ }
+
#ifdef WIN32
- {
+ if (!requestFile) {
/* Put stdin into O_BINARY mode
** or else incoming \r\n's will become \n's.
*/
@@ -737,6 +745,9 @@ thread_main(void *arg)
wc = PR_Send(ps, buf, rc, 0, maxInterval);
} while (wc == rc);
PR_Close(ps);
+ if (requestFile) {
+ PR_Close(std_in);
+ }
}
#endif
@@ -915,22 +926,22 @@ char *hs1SniHostName = NULL;
char *hs2SniHostName = NULL;
PRUint16 portno = 443;
int override = 0;
-char *requestString = NULL;
-PRInt32 requestStringLen = 0;
-PRBool requestSent = PR_FALSE;
PRBool enableZeroRtt = PR_FALSE;
+PRUint8 *zeroRttData;
+unsigned int zeroRttLen = 0;
PRBool enableAltServerHello = PR_FALSE;
PRBool useDTLS = PR_FALSE;
PRBool actAsServer = PR_FALSE;
PRBool stopAfterHandshake = PR_FALSE;
PRBool requestToExit = PR_FALSE;
char *versionString = NULL;
+PRBool handshakeComplete = PR_FALSE;
static int
-writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
+writeBytesToServer(PRFileDesc *s, const PRUint8 *buf, int nb)
{
SECStatus rv;
- const char *bufp = buf;
+ const PRUint8 *bufp = buf;
PRPollDesc pollDesc;
pollDesc.in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT;
@@ -944,12 +955,20 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
if (cc < 0) {
PRErrorCode err = PR_GetError();
if (err != PR_WOULD_BLOCK_ERROR) {
- SECU_PrintError(progName,
- "write to SSL socket failed");
+ SECU_PrintError(progName, "write to SSL socket failed");
return 254;
}
cc = 0;
}
+ FPRINTF(stderr, "%s: %d bytes written\n", progName, cc);
+ if (enableZeroRtt && !handshakeComplete) {
+ if (zeroRttLen + cc > ZERO_RTT_MAX) {
+ SECU_PrintError(progName, "too much early data to save");
+ return -1;
+ }
+ PORT_Memcpy(zeroRttData + zeroRttLen, bufp, cc);
+ zeroRttLen += cc;
+ }
bufp += cc;
nb -= cc;
if (nb <= 0)
@@ -969,8 +988,7 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
progName);
cc = PR_Poll(&pollDesc, 1, PR_INTERVAL_NO_TIMEOUT);
if (cc < 0) {
- SECU_PrintError(progName,
- "PR_Poll failed");
+ SECU_PrintError(progName, "PR_Poll failed");
return -1;
}
FPRINTF(stderr,
@@ -993,7 +1011,7 @@ handshakeCallback(PRFileDesc *fd, void *client_data)
SSL_ReHandshake(fd, (renegotiationsToDo < 2));
++renegotiationsDone;
}
- if (requestString && requestSent) {
+ if (zeroRttLen) {
/* This data was sent in 0-RTT. */
SSLChannelInfo info;
SECStatus rv;
@@ -1003,17 +1021,18 @@ handshakeCallback(PRFileDesc *fd, void *client_data)
return;
if (!info.earlyDataAccepted) {
- FPRINTF(stderr, "Early data rejected. Re-sending\n");
- writeBytesToServer(fd, requestString, requestStringLen);
+ FPRINTF(stderr, "Early data rejected. Re-sending %d bytes\n",
+ zeroRttLen);
+ writeBytesToServer(fd, zeroRttData, zeroRttLen);
+ zeroRttLen = 0;
}
}
if (stopAfterHandshake) {
requestToExit = PR_TRUE;
}
+ handshakeComplete = PR_TRUE;
}
-#define REQUEST_WAITING (requestString && !requestSent)
-
static SECStatus
installServerCertificate(PRFileDesc *s, char *nick)
{
@@ -1136,13 +1155,12 @@ run()
SECStatus rv;
PRStatus status;
PRInt32 filesReady;
- int npds;
PRFileDesc *s = NULL;
PRFileDesc *std_out;
PRPollDesc pollset[2];
PRBool wrStarted = PR_FALSE;
- requestSent = PR_FALSE;
+ handshakeComplete = PR_FALSE;
/* Create socket */
if (useDTLS) {
@@ -1393,7 +1411,6 @@ run()
/* Try to connect to the server */
rv = connectToServer(s, pollset);
if (rv != SECSuccess) {
- ;
error = 1;
goto done;
}
@@ -1405,13 +1422,18 @@ run()
pollset[SSOCK_FD].in_flags |= (clientSpeaksFirst ? 0 : PR_POLL_READ);
else
pollset[SSOCK_FD].in_flags |= PR_POLL_READ;
- pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput);
- if (!REQUEST_WAITING) {
- pollset[STDIN_FD].in_flags = PR_POLL_READ;
- npds = 2;
+ if (requestFile) {
+ pollset[STDIN_FD].fd = PR_Open(requestFile, PR_RDONLY, 0);
+ if (!pollset[STDIN_FD].fd) {
+ fprintf(stderr, "%s: unable to open input file: %s\n",
+ progName, requestFile);
+ error = 1;
+ goto done;
+ }
} else {
- npds = 1;
+ pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput);
}
+ pollset[STDIN_FD].in_flags = PR_POLL_READ;
std_out = PR_GetSpecialFD(PR_StandardOutput);
#if defined(WIN32) || defined(OS2)
@@ -1457,10 +1479,9 @@ run()
requestToExit = PR_FALSE;
FPRINTF(stderr, "%s: ready...\n", progName);
while (!requestToExit &&
- ((pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) ||
- REQUEST_WAITING)) {
- char buf[4000]; /* buffer for stdin */
- int nb; /* num bytes read from stdin. */
+ (pollset[SSOCK_FD].in_flags || pollset[STDIN_FD].in_flags)) {
+ PRUint8 buf[4000]; /* buffer for stdin */
+ int nb; /* num bytes read from stdin. */
rv = restartHandshakeAfterServerCertIfNeeded(s, &serverCertAuth,
override);
@@ -1474,7 +1495,8 @@ run()
pollset[STDIN_FD].out_flags = 0;
FPRINTF(stderr, "%s: about to call PR_Poll !\n", progName);
- filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT);
+ filesReady = PR_Poll(pollset, PR_ARRAY_SIZE(pollset),
+ PR_INTERVAL_NO_TIMEOUT);
if (filesReady < 0) {
SECU_PrintError(progName, "select failed");
error = 1;
@@ -1496,14 +1518,6 @@ run()
"%s: PR_Poll returned 0x%02x for socket out_flags.\n",
progName, pollset[SSOCK_FD].out_flags);
}
- if (REQUEST_WAITING) {
- error = writeBytesToServer(s, requestString, requestStringLen);
- if (error) {
- goto done;
- }
- requestSent = PR_TRUE;
- pollset[SSOCK_FD].in_flags = PR_POLL_READ;
- }
if (pollset[STDIN_FD].out_flags & PR_POLL_READ) {
/* Read from stdin and write to socket */
nb = PR_Read(pollset[STDIN_FD].fd, buf, sizeof(buf));
@@ -1517,6 +1531,8 @@ run()
} else if (nb == 0) {
/* EOF on stdin, stop polling stdin for read. */
pollset[STDIN_FD].in_flags = 0;
+ if (actAsServer)
+ requestToExit = PR_TRUE;
} else {
error = writeBytesToServer(s, buf, nb);
if (error) {
@@ -1531,12 +1547,12 @@ run()
"%s: PR_Poll returned 0x%02x for socket out_flags.\n",
progName, pollset[SSOCK_FD].out_flags);
}
- if ((pollset[SSOCK_FD].out_flags & PR_POLL_READ) ||
- (pollset[SSOCK_FD].out_flags & PR_POLL_ERR)
#ifdef PR_POLL_HUP
- || (pollset[SSOCK_FD].out_flags & PR_POLL_HUP)
+#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR | PR_POLL_HUP)
+#else
+#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR)
#endif
- ) {
+ if (pollset[SSOCK_FD].out_flags & POLL_RECV_FLAGS) {
/* Read from socket and write to stdout */
nb = PR_Recv(pollset[SSOCK_FD].fd, buf, sizeof buf, 0, maxInterval);
FPRINTF(stderr, "%s: Read from server %d bytes\n", progName, nb);
@@ -1553,7 +1569,7 @@ run()
if (skipProtoHeader != PR_TRUE || wrStarted == PR_TRUE) {
PR_Write(std_out, buf, nb);
} else {
- separateReqHeader(std_out, buf, nb, &wrStarted,
+ separateReqHeader(std_out, (char *)buf, nb, &wrStarted,
&headerSeparatorPtrnId);
}
if (verbose)
@@ -1567,42 +1583,10 @@ done:
if (s) {
PR_Close(s);
}
-
- return error;
-}
-
-PRInt32
-ReadFile(const char *filename, char **data)
-{
- char *ret = NULL;
- char buf[8192];
- unsigned int len = 0;
- PRStatus rv;
-
- PRFileDesc *fd = PR_Open(filename, PR_RDONLY, 0);
- if (!fd)
- return -1;
-
- for (;;) {
- rv = PR_Read(fd, buf, sizeof(buf));
- if (rv < 0) {
- PR_Free(ret);
- return rv;
- }
-
- if (!rv)
- break;
-
- ret = PR_Realloc(ret, len + rv);
- if (!ret) {
- return -1;
- }
- PORT_Memcpy(ret + len, buf, rv);
- len += rv;
+ if (requestFile) {
+ PR_Close(pollset[STDIN_FD].fd);
}
-
- *data = ret;
- return len;
+ return error;
}
int
@@ -1667,11 +1651,7 @@ main(int argc, char **argv)
break;
case 'A':
- requestStringLen = ReadFile(optstate->value, &requestString);
- if (requestStringLen < 0) {
- fprintf(stderr, "Couldn't read file %s\n", optstate->value);
- exit(1);
- }
+ requestFile = PORT_Strdup(optstate->value);
break;
case 'C':
@@ -1777,6 +1757,11 @@ main(int argc, char **argv)
case 'Z':
enableZeroRtt = PR_TRUE;
+ zeroRttData = PORT_ZAlloc(ZERO_RTT_MAX);
+ if (!zeroRttData) {
+ fprintf(stderr, "Unable to allocate buffer for 0-RTT\n");
+ exit(1);
+ }
break;
case 'a':
@@ -2059,20 +2044,13 @@ done:
PR_Close(s);
}
- if (hs1SniHostName) {
- PORT_Free(hs1SniHostName);
- }
- if (hs2SniHostName) {
- PORT_Free(hs2SniHostName);
- }
- if (nickname) {
- PORT_Free(nickname);
- }
- if (pwdata.data) {
- PORT_Free(pwdata.data);
- }
+ PORT_Free((void *)requestFile);
+ PORT_Free(hs1SniHostName);
+ PORT_Free(hs2SniHostName);
+ PORT_Free(nickname);
+ PORT_Free(pwdata.data);
PORT_Free(host);
- PORT_Free(requestString);
+ PORT_Free(zeroRttData);
if (enabledGroups) {
PORT_Free(enabledGroups);
diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c
index d3424a7ad..a1d389214 100644
--- a/lib/ssl/sslsecur.c
+++ b/lib/ssl/sslsecur.c
@@ -922,21 +922,30 @@ ssl_SecureSend(sslSocket *ss, const unsigned char *buf, int len, int flags)
*/
if (!ss->firstHsDone) {
PRBool allowEarlySend = PR_FALSE;
+ PRBool firstClientWrite = PR_FALSE;
ssl_Get1stHandshakeLock(ss);
- if (ss->opt.enableFalseStart ||
- (ss->opt.enable0RttData && !ss->sec.isServer)) {
+ /* The client can sometimes send before the handshake is fully
+ * complete. In TLS 1.2: false start; in TLS 1.3: 0-RTT. */
+ if (!ss->sec.isServer &&
+ (ss->opt.enableFalseStart || ss->opt.enable0RttData)) {
ssl_GetSSL3HandshakeLock(ss);
- /* The client can sometimes send before the handshake is fully
- * complete. In TLS 1.2: false start; in TLS 1.3: 0-RTT. */
zeroRtt = ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted;
allowEarlySend = ss->ssl3.hs.canFalseStart || zeroRtt;
+ firstClientWrite = ss->ssl3.hs.ws == idle_handshake;
ssl_ReleaseSSL3HandshakeLock(ss);
}
if (!allowEarlySend && ss->handshake) {
rv = ssl_Do1stHandshake(ss);
}
+ if (firstClientWrite) {
+ /* Wait until after sending ClientHello and double-check 0-RTT. */
+ ssl_GetSSL3HandshakeLock(ss);
+ zeroRtt = ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
+ ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted;
+ ssl_ReleaseSSL3HandshakeLock(ss);
+ }
ssl_Release1stHandshakeLock(ss);
}
diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c
index 5b98a3de4..286b35e7d 100644
--- a/lib/ssl/sslsock.c
+++ b/lib/ssl/sslsock.c
@@ -3039,26 +3039,27 @@ ssl_Poll(PRFileDesc *fd, PRInt16 how_flags, PRInt16 *p_out_flags)
} else { /* handshaking as server */
new_flags |= PR_POLL_READ;
}
- } else
+ } else if (ss->lastWriteBlocked) {
/* First handshake is in progress */
- if (ss->lastWriteBlocked) {
if (new_flags & PR_POLL_READ) {
/* The caller is waiting for data to be received,
** but the initial handshake is blocked on write, or the
** client's first handshake record has not been written.
** The code should select on write, not read.
*/
- new_flags ^= PR_POLL_READ; /* don't select on read. */
+ new_flags &= ~PR_POLL_READ; /* don't select on read. */
new_flags |= PR_POLL_WRITE; /* do select on write. */
}
} else if (new_flags & PR_POLL_WRITE) {
/* The caller is trying to write, but the handshake is
** blocked waiting for data to read, and the first
** handshake has been sent. So do NOT to poll on write
- ** unless we did false start.
+ ** unless we did false start or we are doing 0-RTT.
*/
- if (!ss->ssl3.hs.canFalseStart) {
- new_flags ^= PR_POLL_WRITE; /* don't select on write. */
+ if (!(ss->ssl3.hs.canFalseStart ||
+ ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
+ ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted)) {
+ new_flags &= ~PR_POLL_WRITE; /* don't select on write. */
}
new_flags |= PR_POLL_READ; /* do select on read. */
}
@@ -3098,6 +3099,9 @@ ssl_Poll(PRFileDesc *fd, PRInt16 how_flags, PRInt16 *p_out_flags)
}
}
+ SSL_TRC(20, ("%d: SSL[%d]: ssl_Poll flags %x -> %x",
+ SSL_GETPID(), fd, how_flags, new_flags));
+
if (new_flags && (fd->lower->methods->poll != NULL)) {
PRInt16 lower_out_flags = 0;
PRInt16 lower_new_flags;