diff options
-rw-r--r-- | lib/nl.c | 37 | ||||
-rw-r--r-- | lib/socket.c | 37 | ||||
-rw-r--r-- | tests/test-socket-creation.c | 20 |
3 files changed, 76 insertions, 18 deletions
@@ -194,35 +194,50 @@ int nl_connect(struct nl_handle *handle, int protocol) socklen_t addrlen; handle->h_fd = socket(AF_NETLINK, SOCK_RAW, protocol); - if (handle->h_fd < 0) - return nl_error(1, "socket(AF_NETLINK, ...) failed"); + if (handle->h_fd < 0) { + err = nl_error(1, "socket(AF_NETLINK, ...) failed"); + goto errout; + } if (!(handle->h_flags & NL_SOCK_BUFSIZE_SET)) { err = nl_set_buffer_size(handle, 0, 0); if (err < 0) - return err; + goto errout; } err = bind(handle->h_fd, (struct sockaddr*) &handle->h_local, sizeof(handle->h_local)); - if (err < 0) - return nl_error(1, "bind() failed"); + if (err < 0) { + err = nl_error(1, "bind() failed"); + goto errout; + } addrlen = sizeof(handle->h_local); err = getsockname(handle->h_fd, (struct sockaddr *) &handle->h_local, &addrlen); - if (err < 0) - return nl_error(1, "getsockname failed"); + if (err < 0) { + err = nl_error(1, "getsockname failed"); + goto errout; + } - if (addrlen != sizeof(handle->h_local)) - return nl_error(EADDRNOTAVAIL, "Invalid address length"); + if (addrlen != sizeof(handle->h_local)) { + err = nl_error(EADDRNOTAVAIL, "Invalid address length"); + goto errout; + } - if (handle->h_local.nl_family != AF_NETLINK) - return nl_error(EPFNOSUPPORT, "Address format not supported"); + if (handle->h_local.nl_family != AF_NETLINK) { + err = nl_error(EPFNOSUPPORT, "Address format not supported"); + goto errout; + } handle->h_proto = protocol; return 0; +errout: + close(handle->h_fd); + handle->h_fd = -1; + + return err; } /** diff --git a/lib/socket.c b/lib/socket.c index f68e8cf..4e24c45 100644 --- a/lib/socket.c +++ b/lib/socket.c @@ -127,9 +127,13 @@ static uint32_t generate_local_port(void) static void release_local_port(uint32_t port) { - int nr = port >> 22; + int nr; - used_ports_map[nr / 32] &= ~(nr % 32); + if (port == UINT_MAX) + return; + + nr = port >> 22; + used_ports_map[nr / 32] &= ~((nr % 32) + 1); } /** @@ -147,11 +151,17 @@ static struct nl_handle *__alloc_handle(struct nl_cb *cb) return NULL; } + handle->h_fd = -1; handle->h_cb = cb; handle->h_local.nl_family = AF_NETLINK; - handle->h_local.nl_pid = generate_local_port(); handle->h_peer.nl_family = AF_NETLINK; handle->h_seq_expect = handle->h_seq_next = time(0); + handle->h_local.nl_pid = generate_local_port(); + if (handle->h_local.nl_pid == UINT_MAX) { + nl_handle_destroy(handle); + nl_error(ENOBUFS, "Out of sequence numbers"); + return NULL; + } return handle; } @@ -200,6 +210,9 @@ void nl_handle_destroy(struct nl_handle *handle) if (!handle) return; + if (handle->h_fd >= 0) + close(handle->h_fd); + if (!(handle->h_flags & NL_OWN_PORT)) release_local_port(handle->h_local.nl_pid); @@ -311,6 +324,9 @@ int nl_socket_add_membership(struct nl_handle *handle, int group) { int err; + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); + err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &group, sizeof(group)); if (err < 0) @@ -335,6 +351,9 @@ int nl_socket_drop_membership(struct nl_handle *handle, int group) { int err; + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); + err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group)); if (err < 0) @@ -396,6 +415,9 @@ int nl_socket_get_fd(struct nl_handle *handle) */ int nl_socket_set_nonblocking(struct nl_handle *handle) { + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); + if (fcntl(handle->h_fd, F_SETFL, O_NONBLOCK) < 0) return nl_error(errno, "fcntl(F_SETFL, O_NONBLOCK) failed"); @@ -484,6 +506,9 @@ int nl_set_buffer_size(struct nl_handle *handle, int rxbuf, int txbuf) if (txbuf <= 0) txbuf = 32768; + + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); err = setsockopt(handle->h_fd, SOL_SOCKET, SO_SNDBUF, &txbuf, sizeof(txbuf)); @@ -511,6 +536,9 @@ int nl_set_passcred(struct nl_handle *handle, int state) { int err; + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); + err = setsockopt(handle->h_fd, SOL_SOCKET, SO_PASSCRED, &state, sizeof(state)); if (err < 0) @@ -535,6 +563,9 @@ int nl_socket_recv_pktinfo(struct nl_handle *handle, int state) { int err; + if (handle->h_fd == -1) + return nl_error(EBADFD, "Socket not connected"); + err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_PKTINFO, &state, sizeof(state)); if (err < 0) diff --git a/tests/test-socket-creation.c b/tests/test-socket-creation.c index 5a06661..4066eef 100644 --- a/tests/test-socket-creation.c +++ b/tests/test-socket-creation.c @@ -2,13 +2,25 @@ int main(int argc, char *argv[]) { - struct nl_handle *h; + struct nl_handle *h[1025]; int i; + h[0] = nl_handle_alloc(); + printf("Created handle with port 0x%x\n", + nl_socket_get_local_port(h[0])); + nl_handle_destroy(h[0]); + h[0] = nl_handle_alloc(); + printf("Created handle with port 0x%x\n", + nl_socket_get_local_port(h[0])); + nl_handle_destroy(h[0]); + for (i = 0; i < 1025; i++) { - h = nl_handle_alloc(); - printf("Created handle with port 0x%x\n", - nl_socket_get_local_port(h)); + h[i] = nl_handle_alloc(); + if (h[i] == NULL) + nl_perror("Unable to allocate socket"); + else + printf("Created handle with port 0x%x\n", + nl_socket_get_local_port(h[i])); } return 0; |