diff options
-rw-r--r-- | cmd/tstclnt/tstclnt.c | 168 | ||||
-rw-r--r-- | lib/ssl/sslsecur.c | 17 | ||||
-rw-r--r-- | lib/ssl/sslsock.c | 16 |
3 files changed, 96 insertions, 105 deletions
diff --git a/cmd/tstclnt/tstclnt.c b/cmd/tstclnt/tstclnt.c index 38cbe94b4..f6956596b 100644 --- a/cmd/tstclnt/tstclnt.c +++ b/cmd/tstclnt/tstclnt.c @@ -51,6 +51,7 @@ #define MAX_WAIT_FOR_SERVER 600 #define WAIT_INTERVAL 100 +#define ZERO_RTT_MAX (2 << 16) #define EXIT_CODE_HANDSHAKE_FAILED 254 @@ -99,6 +100,7 @@ int renegotiationsDone = 0; PRBool initializedServerSessionCache = PR_FALSE; static char *progName; +static const char *requestFile; secuPWData pwdata = { PW_NONE, 0 }; @@ -711,12 +713,18 @@ void thread_main(void *arg) { PRFileDesc *ps = (PRFileDesc *)arg; - PRFileDesc *std_in = PR_GetSpecialFD(PR_StandardInput); + PRFileDesc *std_in; int wc, rc; char buf[256]; + if (requestFile) { + std_in = PR_Open(requestFile, PR_RDONLY, 0); + } else { + std_in = PR_GetSpecialFD(PR_StandardInput); + } + #ifdef WIN32 - { + if (!requestFile) { /* Put stdin into O_BINARY mode ** or else incoming \r\n's will become \n's. */ @@ -737,6 +745,9 @@ thread_main(void *arg) wc = PR_Send(ps, buf, rc, 0, maxInterval); } while (wc == rc); PR_Close(ps); + if (requestFile) { + PR_Close(std_in); + } } #endif @@ -915,22 +926,22 @@ char *hs1SniHostName = NULL; char *hs2SniHostName = NULL; PRUint16 portno = 443; int override = 0; -char *requestString = NULL; -PRInt32 requestStringLen = 0; -PRBool requestSent = PR_FALSE; PRBool enableZeroRtt = PR_FALSE; +PRUint8 *zeroRttData; +unsigned int zeroRttLen = 0; PRBool enableAltServerHello = PR_FALSE; PRBool useDTLS = PR_FALSE; PRBool actAsServer = PR_FALSE; PRBool stopAfterHandshake = PR_FALSE; PRBool requestToExit = PR_FALSE; char *versionString = NULL; +PRBool handshakeComplete = PR_FALSE; static int -writeBytesToServer(PRFileDesc *s, const char *buf, int nb) +writeBytesToServer(PRFileDesc *s, const PRUint8 *buf, int nb) { SECStatus rv; - const char *bufp = buf; + const PRUint8 *bufp = buf; PRPollDesc pollDesc; pollDesc.in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT; @@ -944,12 +955,20 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb) if (cc < 0) { PRErrorCode err = PR_GetError(); if (err != PR_WOULD_BLOCK_ERROR) { - SECU_PrintError(progName, - "write to SSL socket failed"); + SECU_PrintError(progName, "write to SSL socket failed"); return 254; } cc = 0; } + FPRINTF(stderr, "%s: %d bytes written\n", progName, cc); + if (enableZeroRtt && !handshakeComplete) { + if (zeroRttLen + cc > ZERO_RTT_MAX) { + SECU_PrintError(progName, "too much early data to save"); + return -1; + } + PORT_Memcpy(zeroRttData + zeroRttLen, bufp, cc); + zeroRttLen += cc; + } bufp += cc; nb -= cc; if (nb <= 0) @@ -969,8 +988,7 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb) progName); cc = PR_Poll(&pollDesc, 1, PR_INTERVAL_NO_TIMEOUT); if (cc < 0) { - SECU_PrintError(progName, - "PR_Poll failed"); + SECU_PrintError(progName, "PR_Poll failed"); return -1; } FPRINTF(stderr, @@ -993,7 +1011,7 @@ handshakeCallback(PRFileDesc *fd, void *client_data) SSL_ReHandshake(fd, (renegotiationsToDo < 2)); ++renegotiationsDone; } - if (requestString && requestSent) { + if (zeroRttLen) { /* This data was sent in 0-RTT. */ SSLChannelInfo info; SECStatus rv; @@ -1003,17 +1021,18 @@ handshakeCallback(PRFileDesc *fd, void *client_data) return; if (!info.earlyDataAccepted) { - FPRINTF(stderr, "Early data rejected. Re-sending\n"); - writeBytesToServer(fd, requestString, requestStringLen); + FPRINTF(stderr, "Early data rejected. Re-sending %d bytes\n", + zeroRttLen); + writeBytesToServer(fd, zeroRttData, zeroRttLen); + zeroRttLen = 0; } } if (stopAfterHandshake) { requestToExit = PR_TRUE; } + handshakeComplete = PR_TRUE; } -#define REQUEST_WAITING (requestString && !requestSent) - static SECStatus installServerCertificate(PRFileDesc *s, char *nick) { @@ -1136,13 +1155,12 @@ run() SECStatus rv; PRStatus status; PRInt32 filesReady; - int npds; PRFileDesc *s = NULL; PRFileDesc *std_out; PRPollDesc pollset[2]; PRBool wrStarted = PR_FALSE; - requestSent = PR_FALSE; + handshakeComplete = PR_FALSE; /* Create socket */ if (useDTLS) { @@ -1393,7 +1411,6 @@ run() /* Try to connect to the server */ rv = connectToServer(s, pollset); if (rv != SECSuccess) { - ; error = 1; goto done; } @@ -1405,13 +1422,18 @@ run() pollset[SSOCK_FD].in_flags |= (clientSpeaksFirst ? 0 : PR_POLL_READ); else pollset[SSOCK_FD].in_flags |= PR_POLL_READ; - pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput); - if (!REQUEST_WAITING) { - pollset[STDIN_FD].in_flags = PR_POLL_READ; - npds = 2; + if (requestFile) { + pollset[STDIN_FD].fd = PR_Open(requestFile, PR_RDONLY, 0); + if (!pollset[STDIN_FD].fd) { + fprintf(stderr, "%s: unable to open input file: %s\n", + progName, requestFile); + error = 1; + goto done; + } } else { - npds = 1; + pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput); } + pollset[STDIN_FD].in_flags = PR_POLL_READ; std_out = PR_GetSpecialFD(PR_StandardOutput); #if defined(WIN32) || defined(OS2) @@ -1457,10 +1479,9 @@ run() requestToExit = PR_FALSE; FPRINTF(stderr, "%s: ready...\n", progName); while (!requestToExit && - ((pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) || - REQUEST_WAITING)) { - char buf[4000]; /* buffer for stdin */ - int nb; /* num bytes read from stdin. */ + (pollset[SSOCK_FD].in_flags || pollset[STDIN_FD].in_flags)) { + PRUint8 buf[4000]; /* buffer for stdin */ + int nb; /* num bytes read from stdin. */ rv = restartHandshakeAfterServerCertIfNeeded(s, &serverCertAuth, override); @@ -1474,7 +1495,8 @@ run() pollset[STDIN_FD].out_flags = 0; FPRINTF(stderr, "%s: about to call PR_Poll !\n", progName); - filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT); + filesReady = PR_Poll(pollset, PR_ARRAY_SIZE(pollset), + PR_INTERVAL_NO_TIMEOUT); if (filesReady < 0) { SECU_PrintError(progName, "select failed"); error = 1; @@ -1496,14 +1518,6 @@ run() "%s: PR_Poll returned 0x%02x for socket out_flags.\n", progName, pollset[SSOCK_FD].out_flags); } - if (REQUEST_WAITING) { - error = writeBytesToServer(s, requestString, requestStringLen); - if (error) { - goto done; - } - requestSent = PR_TRUE; - pollset[SSOCK_FD].in_flags = PR_POLL_READ; - } if (pollset[STDIN_FD].out_flags & PR_POLL_READ) { /* Read from stdin and write to socket */ nb = PR_Read(pollset[STDIN_FD].fd, buf, sizeof(buf)); @@ -1517,6 +1531,8 @@ run() } else if (nb == 0) { /* EOF on stdin, stop polling stdin for read. */ pollset[STDIN_FD].in_flags = 0; + if (actAsServer) + requestToExit = PR_TRUE; } else { error = writeBytesToServer(s, buf, nb); if (error) { @@ -1531,12 +1547,12 @@ run() "%s: PR_Poll returned 0x%02x for socket out_flags.\n", progName, pollset[SSOCK_FD].out_flags); } - if ((pollset[SSOCK_FD].out_flags & PR_POLL_READ) || - (pollset[SSOCK_FD].out_flags & PR_POLL_ERR) #ifdef PR_POLL_HUP - || (pollset[SSOCK_FD].out_flags & PR_POLL_HUP) +#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR | PR_POLL_HUP) +#else +#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR) #endif - ) { + if (pollset[SSOCK_FD].out_flags & POLL_RECV_FLAGS) { /* Read from socket and write to stdout */ nb = PR_Recv(pollset[SSOCK_FD].fd, buf, sizeof buf, 0, maxInterval); FPRINTF(stderr, "%s: Read from server %d bytes\n", progName, nb); @@ -1553,7 +1569,7 @@ run() if (skipProtoHeader != PR_TRUE || wrStarted == PR_TRUE) { PR_Write(std_out, buf, nb); } else { - separateReqHeader(std_out, buf, nb, &wrStarted, + separateReqHeader(std_out, (char *)buf, nb, &wrStarted, &headerSeparatorPtrnId); } if (verbose) @@ -1567,42 +1583,10 @@ done: if (s) { PR_Close(s); } - - return error; -} - -PRInt32 -ReadFile(const char *filename, char **data) -{ - char *ret = NULL; - char buf[8192]; - unsigned int len = 0; - PRStatus rv; - - PRFileDesc *fd = PR_Open(filename, PR_RDONLY, 0); - if (!fd) - return -1; - - for (;;) { - rv = PR_Read(fd, buf, sizeof(buf)); - if (rv < 0) { - PR_Free(ret); - return rv; - } - - if (!rv) - break; - - ret = PR_Realloc(ret, len + rv); - if (!ret) { - return -1; - } - PORT_Memcpy(ret + len, buf, rv); - len += rv; + if (requestFile) { + PR_Close(pollset[STDIN_FD].fd); } - - *data = ret; - return len; + return error; } int @@ -1667,11 +1651,7 @@ main(int argc, char **argv) break; case 'A': - requestStringLen = ReadFile(optstate->value, &requestString); - if (requestStringLen < 0) { - fprintf(stderr, "Couldn't read file %s\n", optstate->value); - exit(1); - } + requestFile = PORT_Strdup(optstate->value); break; case 'C': @@ -1777,6 +1757,11 @@ main(int argc, char **argv) case 'Z': enableZeroRtt = PR_TRUE; + zeroRttData = PORT_ZAlloc(ZERO_RTT_MAX); + if (!zeroRttData) { + fprintf(stderr, "Unable to allocate buffer for 0-RTT\n"); + exit(1); + } break; case 'a': @@ -2059,20 +2044,13 @@ done: PR_Close(s); } - if (hs1SniHostName) { - PORT_Free(hs1SniHostName); - } - if (hs2SniHostName) { - PORT_Free(hs2SniHostName); - } - if (nickname) { - PORT_Free(nickname); - } - if (pwdata.data) { - PORT_Free(pwdata.data); - } + PORT_Free((void *)requestFile); + PORT_Free(hs1SniHostName); + PORT_Free(hs2SniHostName); + PORT_Free(nickname); + PORT_Free(pwdata.data); PORT_Free(host); - PORT_Free(requestString); + PORT_Free(zeroRttData); if (enabledGroups) { PORT_Free(enabledGroups); diff --git a/lib/ssl/sslsecur.c b/lib/ssl/sslsecur.c index d3424a7ad..a1d389214 100644 --- a/lib/ssl/sslsecur.c +++ b/lib/ssl/sslsecur.c @@ -922,21 +922,30 @@ ssl_SecureSend(sslSocket *ss, const unsigned char *buf, int len, int flags) */ if (!ss->firstHsDone) { PRBool allowEarlySend = PR_FALSE; + PRBool firstClientWrite = PR_FALSE; ssl_Get1stHandshakeLock(ss); - if (ss->opt.enableFalseStart || - (ss->opt.enable0RttData && !ss->sec.isServer)) { + /* The client can sometimes send before the handshake is fully + * complete. In TLS 1.2: false start; in TLS 1.3: 0-RTT. */ + if (!ss->sec.isServer && + (ss->opt.enableFalseStart || ss->opt.enable0RttData)) { ssl_GetSSL3HandshakeLock(ss); - /* The client can sometimes send before the handshake is fully - * complete. In TLS 1.2: false start; in TLS 1.3: 0-RTT. */ zeroRtt = ss->ssl3.hs.zeroRttState == ssl_0rtt_sent || ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted; allowEarlySend = ss->ssl3.hs.canFalseStart || zeroRtt; + firstClientWrite = ss->ssl3.hs.ws == idle_handshake; ssl_ReleaseSSL3HandshakeLock(ss); } if (!allowEarlySend && ss->handshake) { rv = ssl_Do1stHandshake(ss); } + if (firstClientWrite) { + /* Wait until after sending ClientHello and double-check 0-RTT. */ + ssl_GetSSL3HandshakeLock(ss); + zeroRtt = ss->ssl3.hs.zeroRttState == ssl_0rtt_sent || + ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted; + ssl_ReleaseSSL3HandshakeLock(ss); + } ssl_Release1stHandshakeLock(ss); } diff --git a/lib/ssl/sslsock.c b/lib/ssl/sslsock.c index 5b98a3de4..286b35e7d 100644 --- a/lib/ssl/sslsock.c +++ b/lib/ssl/sslsock.c @@ -3039,26 +3039,27 @@ ssl_Poll(PRFileDesc *fd, PRInt16 how_flags, PRInt16 *p_out_flags) } else { /* handshaking as server */ new_flags |= PR_POLL_READ; } - } else + } else if (ss->lastWriteBlocked) { /* First handshake is in progress */ - if (ss->lastWriteBlocked) { if (new_flags & PR_POLL_READ) { /* The caller is waiting for data to be received, ** but the initial handshake is blocked on write, or the ** client's first handshake record has not been written. ** The code should select on write, not read. */ - new_flags ^= PR_POLL_READ; /* don't select on read. */ + new_flags &= ~PR_POLL_READ; /* don't select on read. */ new_flags |= PR_POLL_WRITE; /* do select on write. */ } } else if (new_flags & PR_POLL_WRITE) { /* The caller is trying to write, but the handshake is ** blocked waiting for data to read, and the first ** handshake has been sent. So do NOT to poll on write - ** unless we did false start. + ** unless we did false start or we are doing 0-RTT. */ - if (!ss->ssl3.hs.canFalseStart) { - new_flags ^= PR_POLL_WRITE; /* don't select on write. */ + if (!(ss->ssl3.hs.canFalseStart || + ss->ssl3.hs.zeroRttState == ssl_0rtt_sent || + ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted)) { + new_flags &= ~PR_POLL_WRITE; /* don't select on write. */ } new_flags |= PR_POLL_READ; /* do select on read. */ } @@ -3098,6 +3099,9 @@ ssl_Poll(PRFileDesc *fd, PRInt16 how_flags, PRInt16 *p_out_flags) } } + SSL_TRC(20, ("%d: SSL[%d]: ssl_Poll flags %x -> %x", + SSL_GETPID(), fd, how_flags, new_flags)); + if (new_flags && (fd->lower->methods->poll != NULL)) { PRInt16 lower_out_flags = 0; PRInt16 lower_new_flags; |