summaryrefslogtreecommitdiff
path: root/ssh-agent.c
diff options
context:
space:
mode:
authordjm@openbsd.org <djm@openbsd.org>2017-07-19 01:15:02 +0000
committerDamien Miller <djm@mindrot.org>2017-07-21 14:17:33 +1000
commitfd0e8fa5f89d21290b1fb5f9d110ca4f113d81d9 (patch)
treea9b803cc12096cf74eabe13ff7dab974ad3bd09c /ssh-agent.c
parentb1e72df2b813ecc15bd0152167bf4af5f91c36d3 (diff)
downloadopenssh-git-fd0e8fa5f89d21290b1fb5f9d110ca4f113d81d9.tar.gz
upstream commit
switch from select() to poll() for the ssh-agent mainloop; ok markus Upstream-ID: 4a94888ee67b3fd948fd10693973beb12f802448
Diffstat (limited to 'ssh-agent.c')
-rw-r--r--ssh-agent.c312
1 files changed, 185 insertions, 127 deletions
diff --git a/ssh-agent.c b/ssh-agent.c
index eb8c2043..d858c247 100644
--- a/ssh-agent.c
+++ b/ssh-agent.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssh-agent.c,v 1.222 2017/07/01 13:50:45 djm Exp $ */
+/* $OpenBSD: ssh-agent.c,v 1.223 2017/07/19 01:15:02 djm Exp $ */
/*
* Author: Tatu Ylonen <ylo@cs.hut.fi>
* Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -60,6 +60,9 @@
#ifdef HAVE_PATHS_H
# include <paths.h>
#endif
+#ifdef HAVE_POLL_H
+# include <poll.h>
+#endif
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
@@ -91,6 +94,9 @@
# define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*"
#endif
+/* Maximum accepted message length */
+#define AGENT_MAX_LEN (256*1024)
+
typedef enum {
AUTH_UNUSED,
AUTH_SOCKET,
@@ -634,30 +640,46 @@ send:
/* dispatch incoming messages */
-static void
-process_message(SocketEntry *e)
+static int
+process_message(u_int socknum)
{
u_int msg_len;
u_char type;
const u_char *cp;
int r;
+ SocketEntry *e;
+
+ if (socknum >= sockets_alloc) {
+ fatal("%s: socket number %u >= allocated %u",
+ __func__, socknum, sockets_alloc);
+ }
+ e = &sockets[socknum];
if (sshbuf_len(e->input) < 5)
- return; /* Incomplete message. */
+ return 0; /* Incomplete message header. */
cp = sshbuf_ptr(e->input);
msg_len = PEEK_U32(cp);
- if (msg_len > 256 * 1024) {
- close_socket(e);
- return;
+ if (msg_len > AGENT_MAX_LEN) {
+ debug("%s: socket %u (fd=%d) message too long %u > %u",
+ __func__, socknum, e->fd, msg_len, AGENT_MAX_LEN);
+ return -1;
}
if (sshbuf_len(e->input) < msg_len + 4)
- return;
+ return 0; /* Incomplete message body. */
/* move the current input to e->request */
sshbuf_reset(e->request);
if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 ||
- (r = sshbuf_get_u8(e->request, &type)) != 0)
+ (r = sshbuf_get_u8(e->request, &type)) != 0) {
+ if (r == SSH_ERR_MESSAGE_INCOMPLETE ||
+ r == SSH_ERR_STRING_TOO_LARGE) {
+ debug("%s: buffer error: %s", __func__, ssh_err(r));
+ return -1;
+ }
fatal("%s: buffer error: %s", __func__, ssh_err(r));
+ }
+
+ debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type);
/* check wheter agent is locked */
if (locked && type != SSH_AGENTC_UNLOCK) {
@@ -671,10 +693,9 @@ process_message(SocketEntry *e)
/* send a fail message for all other request types */
send_status(e, 0);
}
- return;
+ return 0;
}
- debug("type %d", type);
switch (type) {
case SSH_AGENTC_LOCK:
case SSH_AGENTC_UNLOCK:
@@ -716,6 +737,7 @@ process_message(SocketEntry *e)
send_status(e, 0);
break;
}
+ return 0;
}
static void
@@ -757,19 +779,141 @@ new_socket(sock_type type, int fd)
}
static int
-prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
- struct timeval **tvpp)
+handle_socket_read(u_int socknum)
+{
+ struct sockaddr_un sunaddr;
+ socklen_t slen;
+ uid_t euid;
+ gid_t egid;
+ int fd;
+
+ slen = sizeof(sunaddr);
+ fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen);
+ if (fd < 0) {
+ error("accept from AUTH_SOCKET: %s", strerror(errno));
+ return -1;
+ }
+ if (getpeereid(fd, &euid, &egid) < 0) {
+ error("getpeereid %d failed: %s", fd, strerror(errno));
+ close(fd);
+ return -1;
+ }
+ if ((euid != 0) && (getuid() != euid)) {
+ error("uid mismatch: peer euid %u != uid %u",
+ (u_int) euid, (u_int) getuid());
+ close(fd);
+ return -1;
+ }
+ new_socket(AUTH_CONNECTION, fd);
+ return 0;
+}
+
+static int
+handle_conn_read(u_int socknum)
+{
+ char buf[1024];
+ ssize_t len;
+ int r;
+
+ if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) {
+ if (len == -1) {
+ if (errno == EAGAIN || errno == EINTR)
+ return 0;
+ error("%s: read error on socket %u (fd %d): %s",
+ __func__, socknum, sockets[socknum].fd,
+ strerror(errno));
+ }
+ return -1;
+ }
+ if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0)
+ fatal("%s: buffer error: %s", __func__, ssh_err(r));
+ explicit_bzero(buf, sizeof(buf));
+ process_message(socknum);
+ return 0;
+}
+
+static int
+handle_conn_write(u_int socknum)
+{
+ ssize_t len;
+ int r;
+
+ if (sshbuf_len(sockets[socknum].output) == 0)
+ return 0; /* shouldn't happen */
+ if ((len = write(sockets[socknum].fd,
+ sshbuf_ptr(sockets[socknum].output),
+ sshbuf_len(sockets[socknum].output))) <= 0) {
+ if (len == -1) {
+ if (errno == EAGAIN || errno == EINTR)
+ return 0;
+ error("%s: read error on socket %u (fd %d): %s",
+ __func__, socknum, sockets[socknum].fd,
+ strerror(errno));
+ }
+ return -1;
+ }
+ if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0)
+ fatal("%s: buffer error: %s", __func__, ssh_err(r));
+ return 0;
+}
+
+static void
+after_poll(struct pollfd *pfd, size_t npfd)
{
- u_int i, sz;
- int n = 0;
- static struct timeval tv;
+ size_t i;
+ u_int socknum;
+
+ for (i = 0; i < npfd; i++) {
+ if (pfd[i].revents == 0)
+ continue;
+ /* Find sockets entry */
+ for (socknum = 0; socknum < sockets_alloc; socknum++) {
+ if (sockets[socknum].type != AUTH_SOCKET &&
+ sockets[socknum].type != AUTH_CONNECTION)
+ continue;
+ if (pfd[i].fd == sockets[socknum].fd)
+ break;
+ }
+ if (socknum >= sockets_alloc) {
+ error("%s: no socket for fd %d", __func__, pfd[i].fd);
+ continue;
+ }
+ /* Process events */
+ switch (sockets[socknum].type) {
+ case AUTH_SOCKET:
+ if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
+ handle_socket_read(socknum) != 0)
+ close_socket(&sockets[socknum]);
+ break;
+ case AUTH_CONNECTION:
+ if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
+ handle_conn_read(socknum) != 0) {
+ close_socket(&sockets[socknum]);
+ break;
+ }
+ if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 &&
+ handle_conn_write(socknum) != 0)
+ close_socket(&sockets[socknum]);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+static int
+prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp)
+{
+ struct pollfd *pfd = *pfdp;
+ size_t i, j, npfd = 0;
time_t deadline;
+ /* Count active sockets */
for (i = 0; i < sockets_alloc; i++) {
switch (sockets[i].type) {
case AUTH_SOCKET:
case AUTH_CONNECTION:
- n = MAXIMUM(n, sockets[i].fd);
+ npfd++;
break;
case AUTH_UNUSED:
break;
@@ -778,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
break;
}
}
+ if (npfd != *npfdp &&
+ (pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL)
+ fatal("%s: recallocarray failed", __func__);
+ *pfdp = pfd;
+ *npfdp = npfd;
- sz = howmany(n+1, NFDBITS) * sizeof(fd_mask);
- if (*fdrp == NULL || sz > *nallocp) {
- free(*fdrp);
- free(*fdwp);
- *fdrp = xmalloc(sz);
- *fdwp = xmalloc(sz);
- *nallocp = sz;
- }
- if (n < *fdl)
- debug("XXX shrink: %d < %d", n, *fdl);
- *fdl = n;
- memset(*fdrp, 0, sz);
- memset(*fdwp, 0, sz);
-
- for (i = 0; i < sockets_alloc; i++) {
+ for (i = j = 0; i < sockets_alloc; i++) {
switch (sockets[i].type) {
case AUTH_SOCKET:
case AUTH_CONNECTION:
- FD_SET(sockets[i].fd, *fdrp);
+ pfd[j].fd = sockets[i].fd;
+ pfd[j].revents = 0;
+ /* XXX backoff when input buffer full */
+ pfd[j].events = POLLIN;
if (sshbuf_len(sockets[i].output) > 0)
- FD_SET(sockets[i].fd, *fdwp);
+ pfd[j].events |= POLLOUT;
+ j++;
break;
default:
break;
@@ -810,99 +949,17 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
deadline = (deadline == 0) ? parent_alive_interval :
MINIMUM(deadline, parent_alive_interval);
if (deadline == 0) {
- *tvpp = NULL;
+ *timeoutp = INFTIM;
} else {
- tv.tv_sec = deadline;
- tv.tv_usec = 0;
- *tvpp = &tv;
+ if (deadline > INT_MAX / 1000)
+ *timeoutp = INT_MAX / 1000;
+ else
+ *timeoutp = deadline * 1000;
}
return (1);
}
static void
-after_select(fd_set *readset, fd_set *writeset)
-{
- struct sockaddr_un sunaddr;
- socklen_t slen;
- char buf[1024];
- int len, sock, r;
- u_int i, orig_alloc;
- uid_t euid;
- gid_t egid;
-
- for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++)
- switch (sockets[i].type) {
- case AUTH_UNUSED:
- break;
- case AUTH_SOCKET:
- if (FD_ISSET(sockets[i].fd, readset)) {
- slen = sizeof(sunaddr);
- sock = accept(sockets[i].fd,
- (struct sockaddr *)&sunaddr, &slen);
- if (sock < 0) {
- error("accept from AUTH_SOCKET: %s",
- strerror(errno));
- break;
- }
- if (getpeereid(sock, &euid, &egid) < 0) {
- error("getpeereid %d failed: %s",
- sock, strerror(errno));
- close(sock);
- break;
- }
- if ((euid != 0) && (getuid() != euid)) {
- error("uid mismatch: "
- "peer euid %u != uid %u",
- (u_int) euid, (u_int) getuid());
- close(sock);
- break;
- }
- new_socket(AUTH_CONNECTION, sock);
- }
- break;
- case AUTH_CONNECTION:
- if (sshbuf_len(sockets[i].output) > 0 &&
- FD_ISSET(sockets[i].fd, writeset)) {
- len = write(sockets[i].fd,
- sshbuf_ptr(sockets[i].output),
- sshbuf_len(sockets[i].output));
- if (len == -1 && (errno == EAGAIN ||
- errno == EWOULDBLOCK ||
- errno == EINTR))
- continue;
- if (len <= 0) {
- close_socket(&sockets[i]);
- break;
- }
- if ((r = sshbuf_consume(sockets[i].output,
- len)) != 0)
- fatal("%s: buffer error: %s",
- __func__, ssh_err(r));
- }
- if (FD_ISSET(sockets[i].fd, readset)) {
- len = read(sockets[i].fd, buf, sizeof(buf));
- if (len == -1 && (errno == EAGAIN ||
- errno == EWOULDBLOCK ||
- errno == EINTR))
- continue;
- if (len <= 0) {
- close_socket(&sockets[i]);
- break;
- }
- if ((r = sshbuf_put(sockets[i].input,
- buf, len)) != 0)
- fatal("%s: buffer error: %s",
- __func__, ssh_err(r));
- explicit_bzero(buf, sizeof(buf));
- process_message(&sockets[i]);
- }
- break;
- default:
- fatal("Unknown type %d", sockets[i].type);
- }
-}
-
-static void
cleanup_socket(void)
{
if (cleanup_pid != 0 && getpid() != cleanup_pid)
@@ -963,7 +1020,6 @@ main(int ac, char **av)
int sock, fd, ch, result, saved_errno;
u_int nalloc;
char *shell, *format, *pidstr, *agentsocket = NULL;
- fd_set *readsetp = NULL, *writesetp = NULL;
#ifdef HAVE_SETRLIMIT
struct rlimit rlim;
#endif
@@ -971,9 +1027,11 @@ main(int ac, char **av)
extern char *optarg;
pid_t pid;
char pidstrbuf[1 + 3 * sizeof pid];
- struct timeval *tvp = NULL;
size_t len;
mode_t prev_mask;
+ int timeout = INFTIM;
+ struct pollfd *pfd = NULL;
+ size_t npfd = 0;
ssh_malloc_init(); /* must be called before any mallocs */
/* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */
@@ -1201,8 +1259,8 @@ skip:
platform_pledge_agent();
while (1) {
- prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp);
- result = select(max_fd + 1, readsetp, writesetp, NULL, tvp);
+ prepare_poll(&pfd, &npfd, &timeout);
+ result = poll(pfd, npfd, timeout);
saved_errno = errno;
if (parent_alive_interval != 0)
check_parent_exists();
@@ -1210,9 +1268,9 @@ skip:
if (result < 0) {
if (saved_errno == EINTR)
continue;
- fatal("select: %s", strerror(saved_errno));
+ fatal("poll: %s", strerror(saved_errno));
} else if (result > 0)
- after_select(readsetp, writesetp);
+ after_poll(pfd, npfd);
}
/* NOTREACHED */
}