diff options
author | Yu Watanabe <watanabe.yu+github@gmail.com> | 2022-02-02 07:06:56 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-02 07:06:56 +0900 |
commit | e0ec97514835163ea28786669debeb56228faf2a (patch) | |
tree | 4eaaca7ea35c7e3e5fd7a9833736ab024faa2abc | |
parent | 23b1e8d087c9e8c5a2cdcc6a91510a4e7ca8f72f (diff) | |
parent | c76120f1b82f7e1c6a53b1569087db462c21b7d1 (diff) | |
download | systemd-e0ec97514835163ea28786669debeb56228faf2a.tar.gz |
Merge pull request #22327 from joanbm/main_resolved_improvements
resolved: misc. small DnsStream refactors and improvements
-rw-r--r-- | src/resolve/resolved-dns-stream.c | 98 | ||||
-rw-r--r-- | src/resolve/resolved-dns-stream.h | 3 | ||||
-rw-r--r-- | src/resolve/resolved-dnstls-gnutls.c | 30 | ||||
-rw-r--r-- | src/resolve/resolved-dnstls-openssl.c | 48 | ||||
-rw-r--r-- | src/resolve/resolved-dnstls.h | 5 | ||||
-rw-r--r-- | src/resolve/test-resolved-stream.c | 34 |
6 files changed, 108 insertions, 110 deletions
diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c index cf9d1a9d5e..61e92bea83 100644 --- a/src/resolve/resolved-dns-stream.c +++ b/src/resolve/resolved-dns-stream.c @@ -27,7 +27,7 @@ static void dns_stream_stop(DnsStream *s) { } static int dns_stream_update_io(DnsStream *s) { - int f = 0; + uint32_t f = 0; assert(s); @@ -47,6 +47,8 @@ static int dns_stream_update_io(DnsStream *s) { set_size(s->queries) < DNS_QUERIES_PER_STREAM) f |= EPOLLIN; + s->requested_events = f; + #if ENABLE_DNS_OVER_TLS /* For handshake and clean closing purposes, TLS can override requested events */ if (s->dnstls_events != 0) @@ -208,22 +210,10 @@ ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, assert(iov); #if ENABLE_DNS_OVER_TLS - if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) { - ssize_t ss; - size_t i; - - m = 0; - for (i = 0; i < iovcnt; i++) { - ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len); - if (ss < 0) - return ss; - - m += ss; - if (ss != (ssize_t) iov[i].iov_len) - continue; - } - } else + if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) + return dnstls_stream_writev(s, iov, iovcnt); #endif + if (s->tfo_salen > 0) { struct msghdr hdr = { .msg_iov = (struct iovec*) iov, @@ -289,7 +279,7 @@ static DnsPacket *dns_stream_take_read_packet(DnsStream *s) { * Even this makes a room to read in the stream, this does not call dns_stream_update(), hence * EPOLLIN flag is not set automatically. So, to read further packets from the stream, * dns_stream_update() must be called explicitly. Currently, this is only called from - * on_stream_io_impl(), and there dns_stream_update() is called. */ + * on_stream_io(), and there dns_stream_update() is called. */ if (!s->read_packet) return NULL; @@ -304,15 +294,13 @@ static DnsPacket *dns_stream_take_read_packet(DnsStream *s) { return TAKE_PTR(s->read_packet); } -static int on_stream_io_impl(DnsStream *s, uint32_t revents) { +static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) { + _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */ bool progressed = false; int r; assert(s); - /* This returns 1 when possible remaining stream exists, 0 on completed - stream or recoverable error, and negative errno on failure. */ - #if ENABLE_DNS_OVER_TLS if (s->encrypted) { r = dnstls_stream_on_io(s, revents); @@ -364,9 +352,9 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { } } - if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) && - (!s->read_packet || - s->n_read < sizeof(s->read_size) + s->read_packet->size)) { + while ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) && + (!s->read_packet || + s->n_read < sizeof(s->read_size) + s->read_packet->size)) { if (s->n_read < sizeof(s->read_size)) { ssize_t ss; @@ -375,6 +363,7 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { if (ss < 0) { if (!ERRNO_IS_TRANSIENT(ss)) return dns_stream_complete(s, -ss); + break; } else if (ss == 0) return dns_stream_complete(s, ECONNRESET); else { @@ -428,6 +417,7 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { if (ss < 0) { if (!ERRNO_IS_TRANSIENT(ss)) return dns_stream_complete(s, -ss); + break; } else if (ss == 0) return dns_stream_complete(s, ECONNRESET); else @@ -448,23 +438,19 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { return dns_stream_complete(s, -r); s->packet_received = true; + + /* If we just disabled the read event, stop reading */ + if (!FLAGS_SET(s->requested_events, EPOLLIN)) + break; } } } - if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received) { - uint32_t events; - - /* Complete the stream if finished reading and writing one packet, and there's nothing - * else left to write. */ - - r = sd_event_source_get_io_events(s->io_event_source, &events); - if (r < 0) - return r; - - if (!FLAGS_SET(events, EPOLLOUT)) - return dns_stream_complete(s, 0); - } + /* Complete the stream if finished reading and writing one packet, and there's nothing + * else left to write. */ + if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received && + !FLAGS_SET(s->requested_events, EPOLLOUT)) + return dns_stream_complete(s, 0); /* If we did something, let's restart the timeout event source */ if (progressed && s->timeout_event_source) { @@ -473,44 +459,6 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) { log_warning_errno(errno, "Couldn't restart TCP connection timeout, ignoring: %m"); } - return 1; -} - -static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) { - _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */ - int r; - - assert(s); - - r = on_stream_io_impl(s, revents); - if (r <= 0) - return r; - -#if ENABLE_DNS_OVER_TLS - if (!s->encrypted) - return 0; - - /* When using DNS-over-TLS, the underlying TLS library may read the entire TLS record - and buffer it internally. If this happens, we will not receive further EPOLLIN events, - and unless there's some unrelated activity on the socket, we will hang until time out. - To avoid this, if there's buffered TLS data, generate a "fake" EPOLLIN event. - This is hacky, but it makes this case transparent to the rest of the IO code. */ - while (dnstls_stream_has_buffered_data(s)) { - uint32_t events; - - /* Make sure the stream still wants to process more data... */ - r = sd_event_source_get_io_events(s->io_event_source, &events); - if (r < 0) - return r; - if (!FLAGS_SET(events, EPOLLIN)) - break; - - r = on_stream_io_impl(s, EPOLLIN); - if (r <= 0) - return r; - } -#endif - return 0; } diff --git a/src/resolve/resolved-dns-stream.h b/src/resolve/resolved-dns-stream.h index 1c606365cd..ba4a59e41c 100644 --- a/src/resolve/resolved-dns-stream.h +++ b/src/resolve/resolved-dns-stream.h @@ -61,6 +61,7 @@ struct DnsStream { uint32_t ttl; bool identified; bool packet_received; /* At least one packet is received. Used by LLMNR. */ + uint32_t requested_events; /* only when using TCP fast open */ union sockaddr_union tfo_address; @@ -68,7 +69,7 @@ struct DnsStream { #if ENABLE_DNS_OVER_TLS DnsTlsStreamData dnstls_data; - int dnstls_events; + uint32_t dnstls_events; #endif sd_event_source *io_event_source; diff --git a/src/resolve/resolved-dnstls-gnutls.c b/src/resolve/resolved-dnstls-gnutls.c index 8610cacab6..8c8628ebbb 100644 --- a/src/resolve/resolved-dnstls-gnutls.c +++ b/src/resolve/resolved-dnstls-gnutls.c @@ -6,6 +6,7 @@ #include <gnutls/socket.h> +#include "io-util.h" #include "resolved-dns-stream.h" #include "resolved-dnstls.h" #include "resolved-manager.h" @@ -13,7 +14,7 @@ #define TLS_PROTOCOL_PRIORITY "NORMAL:-VERS-ALL:+VERS-TLS1.3:+VERS-TLS1.2" DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(gnutls_session_t, gnutls_deinit, NULL); -static ssize_t dnstls_stream_writev(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) { +static ssize_t dnstls_stream_vec_push(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) { int r; assert(p); @@ -81,7 +82,7 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) { gnutls_handshake_set_timeout(gs, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); gnutls_transport_set_ptr2(gs, (gnutls_transport_ptr_t) (long) stream->fd, stream); - gnutls_transport_set_vec_push_function(gs, &dnstls_stream_writev); + gnutls_transport_set_vec_push_function(gs, &dnstls_stream_vec_push); stream->encrypted = true; stream->dnstls_data.handshake = gnutls_handshake(gs); @@ -163,15 +164,26 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) { return 0; } -ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { +ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) { ssize_t ss; assert(stream); assert(stream->encrypted); assert(stream->dnstls_data.session); - assert(buf); + assert(iov); + assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0); + + gnutls_record_cork(stream->dnstls_data.session); - ss = gnutls_record_send(stream->dnstls_data.session, buf, count); + for (size_t i = 0; i < iovcnt; i++) { + ss = gnutls_record_send( + stream->dnstls_data.session, + iov[i].iov_base, iov[i].iov_len); + if (ss < 0) + break; + } + + ss = gnutls_record_uncork(stream->dnstls_data.session, 0); if (ss < 0) switch(ss) { case GNUTLS_E_INTERRUPTED: @@ -211,14 +223,6 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { return ss; } -bool dnstls_stream_has_buffered_data(DnsStream *stream) { - assert(stream); - assert(stream->encrypted); - assert(stream->dnstls_data.session); - - return gnutls_record_check_pending(stream->dnstls_data.session) > 0; -} - void dnstls_server_free(DnsServer *server) { assert(server); diff --git a/src/resolve/resolved-dnstls-openssl.c b/src/resolve/resolved-dnstls-openssl.c index 7d264dd367..4d3a88c8da 100644 --- a/src/resolve/resolved-dnstls-openssl.c +++ b/src/resolve/resolved-dnstls-openssl.c @@ -292,15 +292,10 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) { return 0; } -ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { +static ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { int error, r; ssize_t ss; - assert(stream); - assert(stream->encrypted); - assert(stream->dnstls_data.ssl); - assert(buf); - ERR_clear_error(); ss = r = SSL_write(stream->dnstls_data.ssl, buf, count); if (r <= 0) { @@ -329,6 +324,29 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) { return ss; } +ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) { + _cleanup_free_ char *buf = NULL; + size_t count; + + assert(stream); + assert(stream->encrypted); + assert(stream->dnstls_data.ssl); + assert(iov); + assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0); + + if (iovcnt == 1) + return dnstls_stream_write(stream, iov[0].iov_base, iov[0].iov_len); + + /* As of now, OpenSSL can not accumulate multiple writes, so join into a + single buffer. Suboptimal, but better than multiple SSL_write calls. */ + count = IOVEC_TOTAL_SIZE(iov, iovcnt); + buf = new(char, count); + for (size_t i = 0, pos = 0; i < iovcnt; pos += iov[i].iov_len, i++) + memcpy(buf + pos, iov[i].iov_base, iov[i].iov_len); + + return dnstls_stream_write(stream, buf, count); +} + ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { int error, r; ssize_t ss; @@ -343,7 +361,15 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { if (r <= 0) { error = SSL_get_error(stream->dnstls_data.ssl, r); if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) { - stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT; + /* If we receive SSL_ERROR_WANT_READ here, there are two possible scenarios: + * OpenSSL needs to renegotiate (so we want to get an EPOLLIN event), or + * There is no more application data is available, so we can just return + And apparently there's no nice way to distinguish between the two. + To handle this, never set EPOLLIN and just continue as usual. + If OpenSSL really wants to read due to renegotiation, it will tell us + again on SSL_write (at which point we will request EPOLLIN force a read); + or we will just eventually read data anyway while we wait for a packet */ + stream->dnstls_events = error == SSL_ERROR_WANT_READ ? 0 : EPOLLOUT; ss = -EAGAIN; } else if (error == SSL_ERROR_ZERO_RETURN) { stream->dnstls_events = 0; @@ -367,14 +393,6 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) { return ss; } -bool dnstls_stream_has_buffered_data(DnsStream *stream) { - assert(stream); - assert(stream->encrypted); - assert(stream->dnstls_data.ssl); - - return SSL_has_pending(stream->dnstls_data.ssl) > 0; -} - void dnstls_server_free(DnsServer *server) { assert(server); diff --git a/src/resolve/resolved-dnstls.h b/src/resolve/resolved-dnstls.h index ed214dc6c4..cda97e0b12 100644 --- a/src/resolve/resolved-dnstls.h +++ b/src/resolve/resolved-dnstls.h @@ -3,8 +3,8 @@ #if ENABLE_DNS_OVER_TLS -#include <stdbool.h> #include <stdint.h> +#include <sys/uio.h> typedef struct DnsServer DnsServer; typedef struct DnsStream DnsStream; @@ -27,9 +27,8 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server); void dnstls_stream_free(DnsStream *stream); int dnstls_stream_on_io(DnsStream *stream, uint32_t revents); int dnstls_stream_shutdown(DnsStream *stream, int error); -ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count); +ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt); ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count); -bool dnstls_stream_has_buffered_data(DnsStream *stream); void dnstls_server_free(DnsServer *server); diff --git a/src/resolve/test-resolved-stream.c b/src/resolve/test-resolved-stream.c index f9428989f0..beaa855384 100644 --- a/src/resolve/test-resolved-stream.c +++ b/src/resolve/test-resolved-stream.c @@ -2,10 +2,12 @@ #include <arpa/inet.h> #include <fcntl.h> +#include <net/if.h> #include <pthread.h> #include <signal.h> #include <stdlib.h> #include <string.h> +#include <sys/ioctl.h> #include <sys/prctl.h> #include <sys/socket.h> #include <sys/wait.h> @@ -13,6 +15,7 @@ #include "fd-util.h" #include "log.h" +#include "macro.h" #include "process-util.h" #include "resolved-dns-packet.h" #include "resolved-dns-question.h" @@ -144,7 +147,7 @@ static void *tls_dns_server(void *p) { r = safe_fork_full("(test-resolved-stream-tls-openssl)", (int[]) { fd_server, fd_tls }, 2, FORK_RESET_SIGNALS|FORK_CLOSE_ALL_FDS|FORK_DEATHSIG|FORK_LOG|FORK_REOPEN_LOG, &openssl_pid); - assert(r >= 0); + assert_se(r >= 0); if (r == 0) { /* Child */ assert_se(dup2(fd_tls, STDIN_FILENO) >= 0); @@ -200,6 +203,10 @@ static int on_stream_packet(DnsStream *stream, DnsPacket *p) { return 0; } +static int on_stream_complete_do_nothing(DnsStream *s, int error) { + return 0; +} + static void test_dns_stream(bool tls) { Manager manager = {}; _cleanup_(dns_stream_unrefp) DnsStream *stream = NULL; @@ -251,9 +258,10 @@ static void test_dns_stream(bool tls) { /* systemd-resolved uses (and requires) the socket to be in nonblocking mode */ assert_se(fcntl(clientfd, F_SETFL, O_NONBLOCK) >= 0); - /* Initialize DNS stream */ + /* Initialize DNS stream (disabling the default self-destruction + behaviour when no complete callback is set) */ assert_se(dns_stream_new(&manager, &stream, DNS_STREAM_LOOKUP, DNS_PROTOCOL_DNS, - TAKE_FD(clientfd), NULL, on_stream_packet, NULL, + TAKE_FD(clientfd), NULL, on_stream_packet, on_stream_complete_do_nothing, DNS_STREAM_DEFAULT_TIMEOUT_USEC) >= 0); #if ENABLE_DNS_OVER_TLS if (tls) { @@ -322,6 +330,24 @@ static void test_dns_stream(bool tls) { log_info("test-resolved-stream: Finished %s test", tls ? "TLS" : "TCP"); } +static void try_isolate_network(void) { + _cleanup_close_ int socket_fd = -1; + + if (unshare(CLONE_NEWUSER | CLONE_NEWNET) < 0) { + log_warning("test-resolved-stream: Can't create user and network ns, running on host"); + return; + } + + /* Bring up the loopback interfaceon the newly created network namespace */ + struct ifreq req = { .ifr_ifindex = 1 }; + assert_se((socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0)) >= 0); + assert_se(ioctl(socket_fd,SIOCGIFNAME,&req) >= 0); + assert_se(ioctl(socket_fd, SIOCGIFFLAGS, &req) >= 0); + assert_se(FLAGS_SET(req.ifr_flags, IFF_LOOPBACK)); + req.ifr_flags |= IFF_UP; + assert_se(ioctl(socket_fd, SIOCSIFFLAGS, &req) >= 0); +} + int main(int argc, char **argv) { SERVER_ADDRESS = (struct sockaddr_in) { .sin_family = AF_INET, @@ -331,6 +357,8 @@ int main(int argc, char **argv) { test_setup_logging(LOG_DEBUG); + try_isolate_network(); + test_dns_stream(false); #if ENABLE_DNS_OVER_TLS if (system("openssl version >/dev/null 2>&1") != 0) |