summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/nl.c37
-rw-r--r--lib/socket.c37
-rw-r--r--tests/test-socket-creation.c20
3 files changed, 76 insertions, 18 deletions
diff --git a/lib/nl.c b/lib/nl.c
index 3866846..fd69de7 100644
--- a/lib/nl.c
+++ b/lib/nl.c
@@ -194,35 +194,50 @@ int nl_connect(struct nl_handle *handle, int protocol)
socklen_t addrlen;
handle->h_fd = socket(AF_NETLINK, SOCK_RAW, protocol);
- if (handle->h_fd < 0)
- return nl_error(1, "socket(AF_NETLINK, ...) failed");
+ if (handle->h_fd < 0) {
+ err = nl_error(1, "socket(AF_NETLINK, ...) failed");
+ goto errout;
+ }
if (!(handle->h_flags & NL_SOCK_BUFSIZE_SET)) {
err = nl_set_buffer_size(handle, 0, 0);
if (err < 0)
- return err;
+ goto errout;
}
err = bind(handle->h_fd, (struct sockaddr*) &handle->h_local,
sizeof(handle->h_local));
- if (err < 0)
- return nl_error(1, "bind() failed");
+ if (err < 0) {
+ err = nl_error(1, "bind() failed");
+ goto errout;
+ }
addrlen = sizeof(handle->h_local);
err = getsockname(handle->h_fd, (struct sockaddr *) &handle->h_local,
&addrlen);
- if (err < 0)
- return nl_error(1, "getsockname failed");
+ if (err < 0) {
+ err = nl_error(1, "getsockname failed");
+ goto errout;
+ }
- if (addrlen != sizeof(handle->h_local))
- return nl_error(EADDRNOTAVAIL, "Invalid address length");
+ if (addrlen != sizeof(handle->h_local)) {
+ err = nl_error(EADDRNOTAVAIL, "Invalid address length");
+ goto errout;
+ }
- if (handle->h_local.nl_family != AF_NETLINK)
- return nl_error(EPFNOSUPPORT, "Address format not supported");
+ if (handle->h_local.nl_family != AF_NETLINK) {
+ err = nl_error(EPFNOSUPPORT, "Address format not supported");
+ goto errout;
+ }
handle->h_proto = protocol;
return 0;
+errout:
+ close(handle->h_fd);
+ handle->h_fd = -1;
+
+ return err;
}
/**
diff --git a/lib/socket.c b/lib/socket.c
index f68e8cf..4e24c45 100644
--- a/lib/socket.c
+++ b/lib/socket.c
@@ -127,9 +127,13 @@ static uint32_t generate_local_port(void)
static void release_local_port(uint32_t port)
{
- int nr = port >> 22;
+ int nr;
- used_ports_map[nr / 32] &= ~(nr % 32);
+ if (port == UINT_MAX)
+ return;
+
+ nr = port >> 22;
+ used_ports_map[nr / 32] &= ~((nr % 32) + 1);
}
/**
@@ -147,11 +151,17 @@ static struct nl_handle *__alloc_handle(struct nl_cb *cb)
return NULL;
}
+ handle->h_fd = -1;
handle->h_cb = cb;
handle->h_local.nl_family = AF_NETLINK;
- handle->h_local.nl_pid = generate_local_port();
handle->h_peer.nl_family = AF_NETLINK;
handle->h_seq_expect = handle->h_seq_next = time(0);
+ handle->h_local.nl_pid = generate_local_port();
+ if (handle->h_local.nl_pid == UINT_MAX) {
+ nl_handle_destroy(handle);
+ nl_error(ENOBUFS, "Out of sequence numbers");
+ return NULL;
+ }
return handle;
}
@@ -200,6 +210,9 @@ void nl_handle_destroy(struct nl_handle *handle)
if (!handle)
return;
+ if (handle->h_fd >= 0)
+ close(handle->h_fd);
+
if (!(handle->h_flags & NL_OWN_PORT))
release_local_port(handle->h_local.nl_pid);
@@ -311,6 +324,9 @@ int nl_socket_add_membership(struct nl_handle *handle, int group)
{
int err;
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
+
err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
&group, sizeof(group));
if (err < 0)
@@ -335,6 +351,9 @@ int nl_socket_drop_membership(struct nl_handle *handle, int group)
{
int err;
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
+
err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_DROP_MEMBERSHIP,
&group, sizeof(group));
if (err < 0)
@@ -396,6 +415,9 @@ int nl_socket_get_fd(struct nl_handle *handle)
*/
int nl_socket_set_nonblocking(struct nl_handle *handle)
{
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
+
if (fcntl(handle->h_fd, F_SETFL, O_NONBLOCK) < 0)
return nl_error(errno, "fcntl(F_SETFL, O_NONBLOCK) failed");
@@ -484,6 +506,9 @@ int nl_set_buffer_size(struct nl_handle *handle, int rxbuf, int txbuf)
if (txbuf <= 0)
txbuf = 32768;
+
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
err = setsockopt(handle->h_fd, SOL_SOCKET, SO_SNDBUF,
&txbuf, sizeof(txbuf));
@@ -511,6 +536,9 @@ int nl_set_passcred(struct nl_handle *handle, int state)
{
int err;
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
+
err = setsockopt(handle->h_fd, SOL_SOCKET, SO_PASSCRED,
&state, sizeof(state));
if (err < 0)
@@ -535,6 +563,9 @@ int nl_socket_recv_pktinfo(struct nl_handle *handle, int state)
{
int err;
+ if (handle->h_fd == -1)
+ return nl_error(EBADFD, "Socket not connected");
+
err = setsockopt(handle->h_fd, SOL_NETLINK, NETLINK_PKTINFO,
&state, sizeof(state));
if (err < 0)
diff --git a/tests/test-socket-creation.c b/tests/test-socket-creation.c
index 5a06661..4066eef 100644
--- a/tests/test-socket-creation.c
+++ b/tests/test-socket-creation.c
@@ -2,13 +2,25 @@
int main(int argc, char *argv[])
{
- struct nl_handle *h;
+ struct nl_handle *h[1025];
int i;
+ h[0] = nl_handle_alloc();
+ printf("Created handle with port 0x%x\n",
+ nl_socket_get_local_port(h[0]));
+ nl_handle_destroy(h[0]);
+ h[0] = nl_handle_alloc();
+ printf("Created handle with port 0x%x\n",
+ nl_socket_get_local_port(h[0]));
+ nl_handle_destroy(h[0]);
+
for (i = 0; i < 1025; i++) {
- h = nl_handle_alloc();
- printf("Created handle with port 0x%x\n",
- nl_socket_get_local_port(h));
+ h[i] = nl_handle_alloc();
+ if (h[i] == NULL)
+ nl_perror("Unable to allocate socket");
+ else
+ printf("Created handle with port 0x%x\n",
+ nl_socket_get_local_port(h[i]));
}
return 0;