summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYu Watanabe <watanabe.yu+github@gmail.com>2022-02-02 07:06:56 +0900
committerGitHub <noreply@github.com>2022-02-02 07:06:56 +0900
commite0ec97514835163ea28786669debeb56228faf2a (patch)
tree4eaaca7ea35c7e3e5fd7a9833736ab024faa2abc
parent23b1e8d087c9e8c5a2cdcc6a91510a4e7ca8f72f (diff)
parentc76120f1b82f7e1c6a53b1569087db462c21b7d1 (diff)
downloadsystemd-e0ec97514835163ea28786669debeb56228faf2a.tar.gz
Merge pull request #22327 from joanbm/main_resolved_improvements
resolved: misc. small DnsStream refactors and improvements
-rw-r--r--src/resolve/resolved-dns-stream.c98
-rw-r--r--src/resolve/resolved-dns-stream.h3
-rw-r--r--src/resolve/resolved-dnstls-gnutls.c30
-rw-r--r--src/resolve/resolved-dnstls-openssl.c48
-rw-r--r--src/resolve/resolved-dnstls.h5
-rw-r--r--src/resolve/test-resolved-stream.c34
6 files changed, 108 insertions, 110 deletions
diff --git a/src/resolve/resolved-dns-stream.c b/src/resolve/resolved-dns-stream.c
index cf9d1a9d5e..61e92bea83 100644
--- a/src/resolve/resolved-dns-stream.c
+++ b/src/resolve/resolved-dns-stream.c
@@ -27,7 +27,7 @@ static void dns_stream_stop(DnsStream *s) {
}
static int dns_stream_update_io(DnsStream *s) {
- int f = 0;
+ uint32_t f = 0;
assert(s);
@@ -47,6 +47,8 @@ static int dns_stream_update_io(DnsStream *s) {
set_size(s->queries) < DNS_QUERIES_PER_STREAM)
f |= EPOLLIN;
+ s->requested_events = f;
+
#if ENABLE_DNS_OVER_TLS
/* For handshake and clean closing purposes, TLS can override requested events */
if (s->dnstls_events != 0)
@@ -208,22 +210,10 @@ ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt,
assert(iov);
#if ENABLE_DNS_OVER_TLS
- if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
- ssize_t ss;
- size_t i;
-
- m = 0;
- for (i = 0; i < iovcnt; i++) {
- ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
- if (ss < 0)
- return ss;
-
- m += ss;
- if (ss != (ssize_t) iov[i].iov_len)
- continue;
- }
- } else
+ if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA))
+ return dnstls_stream_writev(s, iov, iovcnt);
#endif
+
if (s->tfo_salen > 0) {
struct msghdr hdr = {
.msg_iov = (struct iovec*) iov,
@@ -289,7 +279,7 @@ static DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
* Even this makes a room to read in the stream, this does not call dns_stream_update(), hence
* EPOLLIN flag is not set automatically. So, to read further packets from the stream,
* dns_stream_update() must be called explicitly. Currently, this is only called from
- * on_stream_io_impl(), and there dns_stream_update() is called. */
+ * on_stream_io(), and there dns_stream_update() is called. */
if (!s->read_packet)
return NULL;
@@ -304,15 +294,13 @@ static DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
return TAKE_PTR(s->read_packet);
}
-static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
+static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
+ _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */
bool progressed = false;
int r;
assert(s);
- /* This returns 1 when possible remaining stream exists, 0 on completed
- stream or recoverable error, and negative errno on failure. */
-
#if ENABLE_DNS_OVER_TLS
if (s->encrypted) {
r = dnstls_stream_on_io(s, revents);
@@ -364,9 +352,9 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
}
}
- if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
- (!s->read_packet ||
- s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
+ while ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
+ (!s->read_packet ||
+ s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
if (s->n_read < sizeof(s->read_size)) {
ssize_t ss;
@@ -375,6 +363,7 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
if (ss < 0) {
if (!ERRNO_IS_TRANSIENT(ss))
return dns_stream_complete(s, -ss);
+ break;
} else if (ss == 0)
return dns_stream_complete(s, ECONNRESET);
else {
@@ -428,6 +417,7 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
if (ss < 0) {
if (!ERRNO_IS_TRANSIENT(ss))
return dns_stream_complete(s, -ss);
+ break;
} else if (ss == 0)
return dns_stream_complete(s, ECONNRESET);
else
@@ -448,23 +438,19 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
return dns_stream_complete(s, -r);
s->packet_received = true;
+
+ /* If we just disabled the read event, stop reading */
+ if (!FLAGS_SET(s->requested_events, EPOLLIN))
+ break;
}
}
}
- if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received) {
- uint32_t events;
-
- /* Complete the stream if finished reading and writing one packet, and there's nothing
- * else left to write. */
-
- r = sd_event_source_get_io_events(s->io_event_source, &events);
- if (r < 0)
- return r;
-
- if (!FLAGS_SET(events, EPOLLOUT))
- return dns_stream_complete(s, 0);
- }
+ /* Complete the stream if finished reading and writing one packet, and there's nothing
+ * else left to write. */
+ if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received &&
+ !FLAGS_SET(s->requested_events, EPOLLOUT))
+ return dns_stream_complete(s, 0);
/* If we did something, let's restart the timeout event source */
if (progressed && s->timeout_event_source) {
@@ -473,44 +459,6 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
log_warning_errno(errno, "Couldn't restart TCP connection timeout, ignoring: %m");
}
- return 1;
-}
-
-static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
- _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */
- int r;
-
- assert(s);
-
- r = on_stream_io_impl(s, revents);
- if (r <= 0)
- return r;
-
-#if ENABLE_DNS_OVER_TLS
- if (!s->encrypted)
- return 0;
-
- /* When using DNS-over-TLS, the underlying TLS library may read the entire TLS record
- and buffer it internally. If this happens, we will not receive further EPOLLIN events,
- and unless there's some unrelated activity on the socket, we will hang until time out.
- To avoid this, if there's buffered TLS data, generate a "fake" EPOLLIN event.
- This is hacky, but it makes this case transparent to the rest of the IO code. */
- while (dnstls_stream_has_buffered_data(s)) {
- uint32_t events;
-
- /* Make sure the stream still wants to process more data... */
- r = sd_event_source_get_io_events(s->io_event_source, &events);
- if (r < 0)
- return r;
- if (!FLAGS_SET(events, EPOLLIN))
- break;
-
- r = on_stream_io_impl(s, EPOLLIN);
- if (r <= 0)
- return r;
- }
-#endif
-
return 0;
}
diff --git a/src/resolve/resolved-dns-stream.h b/src/resolve/resolved-dns-stream.h
index 1c606365cd..ba4a59e41c 100644
--- a/src/resolve/resolved-dns-stream.h
+++ b/src/resolve/resolved-dns-stream.h
@@ -61,6 +61,7 @@ struct DnsStream {
uint32_t ttl;
bool identified;
bool packet_received; /* At least one packet is received. Used by LLMNR. */
+ uint32_t requested_events;
/* only when using TCP fast open */
union sockaddr_union tfo_address;
@@ -68,7 +69,7 @@ struct DnsStream {
#if ENABLE_DNS_OVER_TLS
DnsTlsStreamData dnstls_data;
- int dnstls_events;
+ uint32_t dnstls_events;
#endif
sd_event_source *io_event_source;
diff --git a/src/resolve/resolved-dnstls-gnutls.c b/src/resolve/resolved-dnstls-gnutls.c
index 8610cacab6..8c8628ebbb 100644
--- a/src/resolve/resolved-dnstls-gnutls.c
+++ b/src/resolve/resolved-dnstls-gnutls.c
@@ -6,6 +6,7 @@
#include <gnutls/socket.h>
+#include "io-util.h"
#include "resolved-dns-stream.h"
#include "resolved-dnstls.h"
#include "resolved-manager.h"
@@ -13,7 +14,7 @@
#define TLS_PROTOCOL_PRIORITY "NORMAL:-VERS-ALL:+VERS-TLS1.3:+VERS-TLS1.2"
DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(gnutls_session_t, gnutls_deinit, NULL);
-static ssize_t dnstls_stream_writev(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
+static ssize_t dnstls_stream_vec_push(gnutls_transport_ptr_t p, const giovec_t *iov, int iovcnt) {
int r;
assert(p);
@@ -81,7 +82,7 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server) {
gnutls_handshake_set_timeout(gs, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
gnutls_transport_set_ptr2(gs, (gnutls_transport_ptr_t) (long) stream->fd, stream);
- gnutls_transport_set_vec_push_function(gs, &dnstls_stream_writev);
+ gnutls_transport_set_vec_push_function(gs, &dnstls_stream_vec_push);
stream->encrypted = true;
stream->dnstls_data.handshake = gnutls_handshake(gs);
@@ -163,15 +164,26 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
return 0;
}
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
ssize_t ss;
assert(stream);
assert(stream->encrypted);
assert(stream->dnstls_data.session);
- assert(buf);
+ assert(iov);
+ assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+ gnutls_record_cork(stream->dnstls_data.session);
- ss = gnutls_record_send(stream->dnstls_data.session, buf, count);
+ for (size_t i = 0; i < iovcnt; i++) {
+ ss = gnutls_record_send(
+ stream->dnstls_data.session,
+ iov[i].iov_base, iov[i].iov_len);
+ if (ss < 0)
+ break;
+ }
+
+ ss = gnutls_record_uncork(stream->dnstls_data.session, 0);
if (ss < 0)
switch(ss) {
case GNUTLS_E_INTERRUPTED:
@@ -211,14 +223,6 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
return ss;
}
-bool dnstls_stream_has_buffered_data(DnsStream *stream) {
- assert(stream);
- assert(stream->encrypted);
- assert(stream->dnstls_data.session);
-
- return gnutls_record_check_pending(stream->dnstls_data.session) > 0;
-}
-
void dnstls_server_free(DnsServer *server) {
assert(server);
diff --git a/src/resolve/resolved-dnstls-openssl.c b/src/resolve/resolved-dnstls-openssl.c
index 7d264dd367..4d3a88c8da 100644
--- a/src/resolve/resolved-dnstls-openssl.c
+++ b/src/resolve/resolved-dnstls-openssl.c
@@ -292,15 +292,10 @@ int dnstls_stream_shutdown(DnsStream *stream, int error) {
return 0;
}
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
+static ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
int error, r;
ssize_t ss;
- assert(stream);
- assert(stream->encrypted);
- assert(stream->dnstls_data.ssl);
- assert(buf);
-
ERR_clear_error();
ss = r = SSL_write(stream->dnstls_data.ssl, buf, count);
if (r <= 0) {
@@ -329,6 +324,29 @@ ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count) {
return ss;
}
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt) {
+ _cleanup_free_ char *buf = NULL;
+ size_t count;
+
+ assert(stream);
+ assert(stream->encrypted);
+ assert(stream->dnstls_data.ssl);
+ assert(iov);
+ assert(IOVEC_TOTAL_SIZE(iov, iovcnt) > 0);
+
+ if (iovcnt == 1)
+ return dnstls_stream_write(stream, iov[0].iov_base, iov[0].iov_len);
+
+ /* As of now, OpenSSL can not accumulate multiple writes, so join into a
+ single buffer. Suboptimal, but better than multiple SSL_write calls. */
+ count = IOVEC_TOTAL_SIZE(iov, iovcnt);
+ buf = new(char, count);
+ for (size_t i = 0, pos = 0; i < iovcnt; pos += iov[i].iov_len, i++)
+ memcpy(buf + pos, iov[i].iov_base, iov[i].iov_len);
+
+ return dnstls_stream_write(stream, buf, count);
+}
+
ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
int error, r;
ssize_t ss;
@@ -343,7 +361,15 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
if (r <= 0) {
error = SSL_get_error(stream->dnstls_data.ssl, r);
if (IN_SET(error, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE)) {
- stream->dnstls_events = error == SSL_ERROR_WANT_READ ? EPOLLIN : EPOLLOUT;
+ /* If we receive SSL_ERROR_WANT_READ here, there are two possible scenarios:
+ * OpenSSL needs to renegotiate (so we want to get an EPOLLIN event), or
+ * There is no more application data is available, so we can just return
+ And apparently there's no nice way to distinguish between the two.
+ To handle this, never set EPOLLIN and just continue as usual.
+ If OpenSSL really wants to read due to renegotiation, it will tell us
+ again on SSL_write (at which point we will request EPOLLIN force a read);
+ or we will just eventually read data anyway while we wait for a packet */
+ stream->dnstls_events = error == SSL_ERROR_WANT_READ ? 0 : EPOLLOUT;
ss = -EAGAIN;
} else if (error == SSL_ERROR_ZERO_RETURN) {
stream->dnstls_events = 0;
@@ -367,14 +393,6 @@ ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count) {
return ss;
}
-bool dnstls_stream_has_buffered_data(DnsStream *stream) {
- assert(stream);
- assert(stream->encrypted);
- assert(stream->dnstls_data.ssl);
-
- return SSL_has_pending(stream->dnstls_data.ssl) > 0;
-}
-
void dnstls_server_free(DnsServer *server) {
assert(server);
diff --git a/src/resolve/resolved-dnstls.h b/src/resolve/resolved-dnstls.h
index ed214dc6c4..cda97e0b12 100644
--- a/src/resolve/resolved-dnstls.h
+++ b/src/resolve/resolved-dnstls.h
@@ -3,8 +3,8 @@
#if ENABLE_DNS_OVER_TLS
-#include <stdbool.h>
#include <stdint.h>
+#include <sys/uio.h>
typedef struct DnsServer DnsServer;
typedef struct DnsStream DnsStream;
@@ -27,9 +27,8 @@ int dnstls_stream_connect_tls(DnsStream *stream, DnsServer *server);
void dnstls_stream_free(DnsStream *stream);
int dnstls_stream_on_io(DnsStream *stream, uint32_t revents);
int dnstls_stream_shutdown(DnsStream *stream, int error);
-ssize_t dnstls_stream_write(DnsStream *stream, const char *buf, size_t count);
+ssize_t dnstls_stream_writev(DnsStream *stream, const struct iovec *iov, size_t iovcnt);
ssize_t dnstls_stream_read(DnsStream *stream, void *buf, size_t count);
-bool dnstls_stream_has_buffered_data(DnsStream *stream);
void dnstls_server_free(DnsServer *server);
diff --git a/src/resolve/test-resolved-stream.c b/src/resolve/test-resolved-stream.c
index f9428989f0..beaa855384 100644
--- a/src/resolve/test-resolved-stream.c
+++ b/src/resolve/test-resolved-stream.c
@@ -2,10 +2,12 @@
#include <arpa/inet.h>
#include <fcntl.h>
+#include <net/if.h>
#include <pthread.h>
#include <signal.h>
#include <stdlib.h>
#include <string.h>
+#include <sys/ioctl.h>
#include <sys/prctl.h>
#include <sys/socket.h>
#include <sys/wait.h>
@@ -13,6 +15,7 @@
#include "fd-util.h"
#include "log.h"
+#include "macro.h"
#include "process-util.h"
#include "resolved-dns-packet.h"
#include "resolved-dns-question.h"
@@ -144,7 +147,7 @@ static void *tls_dns_server(void *p) {
r = safe_fork_full("(test-resolved-stream-tls-openssl)", (int[]) { fd_server, fd_tls }, 2,
FORK_RESET_SIGNALS|FORK_CLOSE_ALL_FDS|FORK_DEATHSIG|FORK_LOG|FORK_REOPEN_LOG, &openssl_pid);
- assert(r >= 0);
+ assert_se(r >= 0);
if (r == 0) {
/* Child */
assert_se(dup2(fd_tls, STDIN_FILENO) >= 0);
@@ -200,6 +203,10 @@ static int on_stream_packet(DnsStream *stream, DnsPacket *p) {
return 0;
}
+static int on_stream_complete_do_nothing(DnsStream *s, int error) {
+ return 0;
+}
+
static void test_dns_stream(bool tls) {
Manager manager = {};
_cleanup_(dns_stream_unrefp) DnsStream *stream = NULL;
@@ -251,9 +258,10 @@ static void test_dns_stream(bool tls) {
/* systemd-resolved uses (and requires) the socket to be in nonblocking mode */
assert_se(fcntl(clientfd, F_SETFL, O_NONBLOCK) >= 0);
- /* Initialize DNS stream */
+ /* Initialize DNS stream (disabling the default self-destruction
+ behaviour when no complete callback is set) */
assert_se(dns_stream_new(&manager, &stream, DNS_STREAM_LOOKUP, DNS_PROTOCOL_DNS,
- TAKE_FD(clientfd), NULL, on_stream_packet, NULL,
+ TAKE_FD(clientfd), NULL, on_stream_packet, on_stream_complete_do_nothing,
DNS_STREAM_DEFAULT_TIMEOUT_USEC) >= 0);
#if ENABLE_DNS_OVER_TLS
if (tls) {
@@ -322,6 +330,24 @@ static void test_dns_stream(bool tls) {
log_info("test-resolved-stream: Finished %s test", tls ? "TLS" : "TCP");
}
+static void try_isolate_network(void) {
+ _cleanup_close_ int socket_fd = -1;
+
+ if (unshare(CLONE_NEWUSER | CLONE_NEWNET) < 0) {
+ log_warning("test-resolved-stream: Can't create user and network ns, running on host");
+ return;
+ }
+
+ /* Bring up the loopback interfaceon the newly created network namespace */
+ struct ifreq req = { .ifr_ifindex = 1 };
+ assert_se((socket_fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0)) >= 0);
+ assert_se(ioctl(socket_fd,SIOCGIFNAME,&req) >= 0);
+ assert_se(ioctl(socket_fd, SIOCGIFFLAGS, &req) >= 0);
+ assert_se(FLAGS_SET(req.ifr_flags, IFF_LOOPBACK));
+ req.ifr_flags |= IFF_UP;
+ assert_se(ioctl(socket_fd, SIOCSIFFLAGS, &req) >= 0);
+}
+
int main(int argc, char **argv) {
SERVER_ADDRESS = (struct sockaddr_in) {
.sin_family = AF_INET,
@@ -331,6 +357,8 @@ int main(int argc, char **argv) {
test_setup_logging(LOG_DEBUG);
+ try_isolate_network();
+
test_dns_stream(false);
#if ENABLE_DNS_OVER_TLS
if (system("openssl version >/dev/null 2>&1") != 0)