From d0681e2eacbe08c4b02de798cddffcae2c003a61 Mon Sep 17 00:00:00 2001 From: Dennis Jackson Date: Thu, 16 Jun 2022 11:22:49 +0000 Subject: Bug 1617956 - Add support for asynchronous client auth hooks. r=mt Differential Revision: https://phabricator.services.mozilla.com/D138149 --- cmd/tstclnt/tstclnt.c | 72 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) (limited to 'cmd') diff --git a/cmd/tstclnt/tstclnt.c b/cmd/tstclnt/tstclnt.c index 453842b16..cbf824ec1 100644 --- a/cmd/tstclnt/tstclnt.c +++ b/cmd/tstclnt/tstclnt.c @@ -268,7 +268,7 @@ PrintParameterUsage() fprintf(stderr, "%-20s Send TLS_FALLBACK_SCSV\n", "-K"); fprintf(stderr, "%-20s Prints only payload data. Skips HTTP header.\n", "-S"); fprintf(stderr, "%-20s Client speaks first. \n", "-f"); - fprintf(stderr, "%-20s Use synchronous certificate validation\n", "-O"); + fprintf(stderr, "%-20s Use synchronous certificate selection & validation\n", "-O"); fprintf(stderr, "%-20s Override bad server cert. Make it OK.\n", "-o"); fprintf(stderr, "%-20s Disable SSL socket locking.\n", "-s"); fprintf(stderr, "%-20s Verbose progress reporting.\n", "-v"); @@ -735,6 +735,43 @@ ownAuthCertificate(void *arg, PRFileDesc *fd, PRBool checkSig, return SECWouldBlock; } +struct clientCertAsyncParamsStr { + void *arg; /* The nickname used for selection, not owned */ + struct CERTDistNamesStr *caNames; /* CA Names specified by Server, owned. */ +}; + +/* tstclnt can only have a single handshake in progress at any instant. */ +PRBool clientCertAsyncSelect = PR_TRUE; /* Async by default */ +PRBool clientCertIsBlocked = PR_FALSE; /* Whether we waiting to finish ClientAuth */ +struct clientCertAsyncParamsStr *clientCertParams = NULL; + +SECStatus +own_CompleteClientAuthData(PRFileDesc *fd, struct clientCertAsyncParamsStr *args) +{ + SECStatus rv; + CERTCertificate *pRetCert = NULL; + SECKEYPrivateKey *pRetKey = NULL; + rv = NSS_GetClientAuthData(args->arg, fd, args->caNames, &pRetCert, &pRetKey); + if (rv != SECSuccess) { + fprintf(stderr, "Failed to load a suitable client certificate \n"); + } + return SSL_ClientCertCallbackComplete(fd, rv, pRetKey, pRetCert); +} + +SECStatus +restartHandshakeAfterClientCertIfNeeded(PRFileDesc *fd) +{ + if (!clientCertIsBlocked) { + return SECFailure; + } + clientCertIsBlocked = PR_FALSE; + own_CompleteClientAuthData(fd, clientCertParams); + CERT_FreeDistNames(clientCertParams->caNames); + PORT_Free(clientCertParams); + clientCertParams = NULL; + return SECSuccess; +} + SECStatus own_GetClientAuthData(void *arg, PRFileDesc *socket, @@ -742,6 +779,26 @@ own_GetClientAuthData(void *arg, struct CERTCertificateStr **pRetCert, struct SECKEYPrivateKeyStr **pRetKey) { + if (clientCertAsyncSelect) { + PR_ASSERT(!clientCertIsBlocked); + PR_ASSERT(!clientCertParams); + + clientCertIsBlocked = PR_TRUE; + clientCertParams = PORT_Alloc(sizeof(struct clientCertAsyncParamsStr)); + if (!clientCertParams) { + fprintf(stderr, "Unable to allocate buffer for client cert callback\n"); + exit(1); + } + + clientCertParams->arg = arg; + clientCertParams->caNames = caNames ? CERT_DupDistNames(caNames) : NULL; + if (caNames && !clientCertParams->caNames) { + fprintf(stderr, "Unable to allocate buffer for client cert callback\n"); + exit(1); + } + return SECWouldBlock; + } + if (verbose > 1) { SECStatus rv; fprintf(stderr, "Server requested Client Authentication\n"); @@ -944,7 +1001,7 @@ restartHandshakeAfterServerCertIfNeeded(PRFileDesc *fd, SECStatus rv; PRErrorCode error = 0; - if (!serverCertAuth->isPaused) + if (!serverCertAuth->isPaused || clientCertIsBlocked) return SECSuccess; FPRINTF(stderr, "%s: handshake was paused by auth certificate hook\n", @@ -1065,6 +1122,11 @@ writeBytesToServer(PRFileDesc *s, const PRUint8 *buf, int nb) return EXIT_CODE_HANDSHAKE_FAILED; } + rv = restartHandshakeAfterClientCertIfNeeded(s); + if (rv == SECSuccess) { + continue; + } + pollDesc.in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT; pollDesc.out_flags = 0; FPRINTF(stderr, @@ -1715,6 +1777,11 @@ run() goto done; } + rv = restartHandshakeAfterClientCertIfNeeded(s); + if (rv == SECSuccess) { + continue; + } + pollset[SSOCK_FD].out_flags = 0; pollset[STDIN_FD].out_flags = 0; @@ -1912,6 +1979,7 @@ main(int argc, char **argv) break; case 'O': + clientCertAsyncSelect = PR_FALSE; serverCertAuth.shouldPause = PR_FALSE; break; -- cgit v1.2.1