summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/resolve/resolved-dns-stream.c2
-rw-r--r--src/resolve/resolved-dnstls-gnutls.c2
-rw-r--r--src/resolve/resolved-dnstls-openssl.c152
-rw-r--r--src/resolve/resolved-dnstls-openssl.h2
-rw-r--r--src/resolve/resolved-dnstls.h2
5 files changed, 142 insertions, 18 deletions
diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c
index faf5e26ba4..8c6f217ad9 100644
--- a/src/resolve/resolved-dns-stream.c
+++ b/src/resolve/resolved-dns-stream.c
@@ -280,7 +280,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
#if ENABLE_DNS_OVER_TLS
if (s->encrypted) {
- r = dnstls_stream_on_io(s);
+ r = dnstls_stream_on_io(s, revents);
if (r == DNSTLS_STREAM_CLOSED)
return 0;
diff --git a/src/resolve/resolved-dnstls-gnutls.c b/src/resolve/resolved-dnstls-gnutls.c
index 5e6a899db8..820e1926fd 100644
--- a/src/resolve/resolved-dnstls-gnutls.c
+++ b/src/resolve/resolved-dnstls-gnutls.c
@@ -77,7 +77,7 @@ void dnstls_stream_free(DnsStream *stream) {
gnutls_deinit(stream->dnstls_data.session);
}
-int dnstls_stream_on_io(DnsStream *stream) {
+int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) {
int r;
assert(stream);
diff --git a/src/resolve/resolved-dnstls-openssl.c b/src/resolve/resolved-dnstls-openssl.c
index d0a1bba773..5dd7737337 100644
--- a/src/resolve/resolved-dnstls-openssl.c
+++ b/src/resolve/resolved-dnstls-openssl.c
@@ -13,31 +13,84 @@
DEFINE_TRIVIAL_CLEANUP_FUNC(SSL*, SSL_free);
DEFINE_TRIVIAL_CLEANUP_FUNC(BIO*, BIO_free);
+static int dnstls_flush_write_buffer(DnsStream *stream) {
+ ssize_t ss;
+
+ assert(stream);
+ assert(stream->encrypted);
+
+ if (stream->dnstls_data.write_buffer->length > 0) {
+ assert(stream->dnstls_data.write_buffer->data);
+
+ struct iovec iov[1];
+ iov[0].iov_base = stream->dnstls_data.write_buffer->data;
+ iov[0].iov_len = stream->dnstls_data.write_buffer->length;
+ ss = dns_stream_writev(stream, iov, 1, DNS_STREAM_WRITE_TLS_DATA);
+ if (ss < 0) {
+ if (ss == -EAGAIN)
+ stream->dnstls_events |= EPOLLOUT;
+
+ return ss;
+ } else {
+ stream->dnstls_data.write_buffer->length -= ss;
+ stream->dnstls_data.write_buffer->data += ss;
+
+ if (stream->dnstls_data.write_buffer->length > 0) {
+ stream->dnstls_events |= EPOLLOUT;
+ return -EAGAIN;
+ }
+ }
+ }
+
+ return 0;
+}
+
int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) {
_cleanup_(SSL_freep) SSL *s = NULL;
- _cleanup_(BIO_freep) BIO *b = NULL;
+ _cleanup_(BIO_freep) BIO *rb = NULL;
+ _cleanup_(BIO_freep) BIO *wb = NULL;
+ int r;
+ int error;
assert(stream);
assert(server);
- b = BIO_new_socket(stream->fd, 0);
- if (!b)
+ rb = BIO_new_socket(stream->fd, 0);
+ if (!rb)
+ return -ENOMEM;
+
+ wb = BIO_new(BIO_s_mem());
+ if (!wb)
return -ENOMEM;
+ BIO_get_mem_ptr(wb, &stream->dnstls_data.write_buffer);
+
s = SSL_new(server->dnstls_data.ctx);
if (!s)
return -ENOMEM;
SSL_set_connect_state(s);
- SSL_set_bio(s, b, b);
- b = NULL;
+ SSL_set_session(s, server->dnstls_data.session);
+ SSL_set_bio(s, TAKE_PTR(rb), TAKE_PTR(wb));
+
+ stream->dnstls_data.handshake = SSL_do_handshake(s);
+ if (stream->dnstls_data.handshake <= 0) {
+ error = SSL_get_error(s, stream->dnstls_data.handshake);
+ if (!IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
+ char errbuf[256];
- /* DNS-over-TLS using OpenSSL doesn't support TCP Fast Open yet */
- connect(stream->fd, &stream->tfo_address.sa, stream->tfo_salen);
- stream->tfo_salen = 0;
+ ERR_error_string_n(error, errbuf, sizeof(errbuf));
+ log_debug("Failed to invoke SSL_do_handshake: %s", errbuf);
+ return -ECONNREFUSED;
+ }
+ }
stream->encrypted = true;
- stream->dnstls_events = EPOLLOUT;
+
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0 && r != -EAGAIN)
+ return r;
+
stream->dnstls_data.ssl = TAKE_PTR(s);
return 0;
@@ -51,7 +104,7 @@ void dnstls_stream_free(DnsStream *stream) {
SSL_free(stream->dnstls_data.ssl);
}
-int dnstls_stream_on_io(DnsStream *stream) {
+int dnstls_stream_on_io(DnsStream *stream, uint32_t revents) {
int r;
int error;
@@ -59,14 +112,25 @@ int dnstls_stream_on_io(DnsStream *stream) {
assert(stream->encrypted);
assert(stream->dnstls_data.ssl);
+ /* Flush write buffer when requested by OpenSSL ss*/
+ if ((revents & EPOLLOUT) && (stream->dnstls_events & EPOLLOUT)) {
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+ }
+
if (stream->dnstls_data.shutdown) {
r = SSL_shutdown(stream->dnstls_data.ssl);
- if (r == 0)
- return -EAGAIN;
- else if (r < 0) {
+ if (r <= 0) {
error = SSL_get_error(stream->dnstls_data.ssl, r);
- if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
- stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+ if (r == 0 || IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
+ if (r < 0)
+ stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
return -EAGAIN;
} else {
char errbuf[256];
@@ -76,6 +140,10 @@ int dnstls_stream_on_io(DnsStream *stream) {
}
}
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
stream->dnstls_events = 0;
stream->dnstls_data.shutdown = false;
dns_stream_unref(stream);
@@ -86,6 +154,10 @@ int dnstls_stream_on_io(DnsStream *stream) {
error = SSL_get_error(stream->dnstls_data.ssl, stream->dnstls_data.handshake);
if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
return -EAGAIN;
} else {
char errbuf[256];
@@ -97,6 +169,9 @@ int dnstls_stream_on_io(DnsStream *stream) {
}
stream->dnstls_events = 0;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
}
return 0;
@@ -111,6 +186,16 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
assert(stream->encrypted);
assert(stream->dnstls_data.ssl);
+ if (stream->server) {
+ s = SSL_get1_session(stream->dnstls_data.ssl);
+ if (s) {
+ if (stream->server->dnstls_data.session)
+ SSL_SESSION_free(stream->server->dnstls_data.session);
+
+ stream->server->dnstls_data.session = s;
+ }
+ }
+
if (error == ETIMEDOUT) {
r = SSL_shutdown(stream->dnstls_data.ssl);
if (r == 0) {
@@ -118,11 +203,20 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
stream->dnstls_data.shutdown = true;
dns_stream_ref(stream);
}
+
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
return -EAGAIN;
} else if (r < 0) {
ssl_error = SSL_get_error(stream->dnstls_data.ssl, r);
if (IN_SET(ssl_error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
stream->dnstls_events = ssl_error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0 && r != -EAGAIN)
+ return r;
+
if (!stream->dnstls_data.shutdown) {
stream->dnstls_data.shutdown = true;
dns_stream_ref(stream);
@@ -135,6 +229,11 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
log_debug("Failed to invoke SSL_shutdown: %s", errbuf);
}
}
+
+ stream->dnstls_events = 0;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
}
return 0;
@@ -155,6 +254,10 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
error = SSL_get_error(stream->dnstls_data.ssl, ss);
if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
ss = -EAGAIN;
} else {
char errbuf[256];
@@ -166,6 +269,10 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
}
stream->dnstls_events = 0;
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
return ss;
}
@@ -184,6 +291,12 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
error = SSL_get_error(stream->dnstls_data.ssl, ss);
if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+
+ /* flush write buffer in cache of renegotiation */
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
ss = -EAGAIN;
} else {
char errbuf[256];
@@ -195,6 +308,12 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
}
stream->dnstls_events = 0;
+
+ /* flush write buffer in cache of renegotiation */
+ r = dnstls_flush_write_buffer(stream);
+ if (r < 0)
+ return r;
+
return ss;
}
@@ -213,4 +332,7 @@ void dnstls_server_free(DnsServer *server) {
if (server->dnstls_data.ctx)
SSL_CTX_free(server->dnstls_data.ctx);
+
+ if (server->dnstls_data.session)
+ SSL_SESSION_free(server->dnstls_data.session);
}
diff --git a/src/resolve/resolved-dnstls-openssl.h b/src/resolve/resolved-dnstls-openssl.h
index c92d2b2354..c57bc1c57c 100644
--- a/src/resolve/resolved-dnstls-openssl.h
+++ b/src/resolve/resolved-dnstls-openssl.h
@@ -11,10 +11,12 @@
struct DnsTlsServerData {
SSL_CTX *ctx;
+ SSL_SESSION *session;
};
struct DnsTlsStreamData {
int handshake;
bool shutdown;
SSL *ssl;
+ BUF_MEM *write_buffer;
};
diff --git a/src/resolve/resolved-dnstls.h b/src/resolve/resolved-dnstls.h
index 52af3e9801..fdd85eece6 100644
--- a/src/resolve/resolved-dnstls.h
+++ b/src/resolve/resolved-dnstls.h
@@ -23,7 +23,7 @@ typedef struct DnsTlsStreamData DnsTlsStreamData;
int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server);
void dnstls_stream_free(DnsStream *stream);
-int dnstls_stream_on_io(DnsStream *stream);
+int dnstls_stream_on_io(DnsStream *stream, uint32_t revents);
int dnstls_stream_shutdown(DnsStream *stream, int error);
ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count);
ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count);