diff options
Diffstat (limited to 'net/mptcp/pm_netlink.c')
-rw-r--r-- | net/mptcp/pm_netlink.c | 200 |
1 files changed, 114 insertions, 86 deletions
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 4b5d795383cd..b5e8de6f7507 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -83,16 +83,6 @@ static bool addresses_equal(const struct mptcp_addr_info *a, return a->port == b->port; } -static bool address_zero(const struct mptcp_addr_info *addr) -{ - struct mptcp_addr_info zero; - - memset(&zero, 0, sizeof(zero)); - zero.family = addr->family; - - return addresses_equal(addr, &zero, true); -} - static void local_address(const struct sock_common *skc, struct mptcp_addr_info *addr) { @@ -120,7 +110,7 @@ static void remote_address(const struct sock_common *skc, } static bool lookup_subflow_by_saddr(const struct list_head *list, - struct mptcp_addr_info *saddr) + const struct mptcp_addr_info *saddr) { struct mptcp_subflow_context *subflow; struct mptcp_addr_info cur; @@ -138,7 +128,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list, } static bool lookup_subflow_by_daddr(const struct list_head *list, - struct mptcp_addr_info *daddr) + const struct mptcp_addr_info *daddr) { struct mptcp_subflow_context *subflow; struct mptcp_addr_info cur; @@ -157,10 +147,10 @@ static bool lookup_subflow_by_daddr(const struct list_head *list, static struct mptcp_pm_addr_entry * select_local_address(const struct pm_nl_pernet *pernet, - struct mptcp_sock *msk) + const struct mptcp_sock *msk) { + const struct sock *sk = (const struct sock *)msk; struct mptcp_pm_addr_entry *entry, *ret = NULL; - struct sock *sk = (struct sock *)msk; msk_owned_by_me(msk); @@ -190,7 +180,7 @@ select_local_address(const struct pm_nl_pernet *pernet, } static struct mptcp_pm_addr_entry * -select_signal_address(struct pm_nl_pernet *pernet, struct mptcp_sock *msk) +select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk) { struct mptcp_pm_addr_entry *entry, *ret = NULL; @@ -214,16 +204,16 @@ select_signal_address(struct pm_nl_pernet *pernet, struct mptcp_sock *msk) return ret; } -unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk) +unsigned int mptcp_pm_get_add_addr_signal_max(const struct mptcp_sock *msk) { - struct pm_nl_pernet *pernet; + const struct pm_nl_pernet *pernet; - pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); + pernet = net_generic(sock_net((const struct sock *)msk), pm_nl_pernet_id); return READ_ONCE(pernet->add_addr_signal_max); } EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max); -unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk) +unsigned int mptcp_pm_get_add_addr_accept_max(const struct mptcp_sock *msk) { struct pm_nl_pernet *pernet; @@ -232,7 +222,7 @@ unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk) } EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max); -unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk) +unsigned int mptcp_pm_get_subflows_max(const struct mptcp_sock *msk) { struct pm_nl_pernet *pernet; @@ -241,7 +231,7 @@ unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk) } EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max); -unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk) +unsigned int mptcp_pm_get_local_addr_max(const struct mptcp_sock *msk) { struct pm_nl_pernet *pernet; @@ -264,8 +254,8 @@ bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk) } struct mptcp_pm_add_entry * -mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk, - struct mptcp_addr_info *addr) +mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) { struct mptcp_pm_add_entry *entry; @@ -346,7 +336,7 @@ out: struct mptcp_pm_add_entry * mptcp_pm_del_add_timer(struct mptcp_sock *msk, - struct mptcp_addr_info *addr, bool check_id) + const struct mptcp_addr_info *addr, bool check_id) { struct mptcp_pm_add_entry *entry; struct sock *sk = (struct sock *)msk; @@ -364,7 +354,7 @@ mptcp_pm_del_add_timer(struct mptcp_sock *msk, } static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk, - struct mptcp_pm_addr_entry *entry) + const struct mptcp_pm_addr_entry *entry) { struct mptcp_pm_add_entry *add_entry = NULL; struct sock *sk = (struct sock *)msk; @@ -410,8 +400,8 @@ void mptcp_pm_free_anno_list(struct mptcp_sock *msk) } } -static bool lookup_address_in_vec(struct mptcp_addr_info *addrs, unsigned int nr, - struct mptcp_addr_info *addr) +static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned int nr, + const struct mptcp_addr_info *addr) { int i; @@ -493,9 +483,9 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info, } static int -lookup_id_by_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *addr) +lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_info *addr) { - struct mptcp_pm_addr_entry *entry; + const struct mptcp_pm_addr_entry *entry; int ret = -1; rcu_read_lock(); @@ -660,7 +650,6 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) unsigned int add_addr_accept_max; struct mptcp_addr_info remote; unsigned int subflows_max; - bool reset_port = false; int i, nr; add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk); @@ -671,14 +660,15 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) msk->pm.remote.family); remote = msk->pm.remote; + mptcp_pm_announce_addr(msk, &remote, true); + mptcp_pm_nl_addr_send_ack(msk); + if (lookup_subflow_by_daddr(&msk->conn_list, &remote)) - goto add_addr_echo; + return; /* pick id 0 port, if none is provided the remote address */ - if (!remote.port) { - reset_port = true; + if (!remote.port) remote.port = sk->sk_dport; - } /* connect to the specified remote address, using whatever * local address the routing configuration will pick. @@ -694,14 +684,6 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) for (i = 0; i < nr; i++) __mptcp_subflow_connect(sk, &addrs[i], &remote); spin_lock_bh(&msk->pm.lock); - - /* be sure to echo exactly the received address */ - if (reset_port) - remote.port = 0; - -add_addr_echo: - mptcp_pm_announce_addr(msk, &remote, true); - mptcp_pm_nl_addr_send_ack(msk); } void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk) @@ -877,10 +859,18 @@ static bool address_use_port(struct mptcp_pm_addr_entry *entry) MPTCP_PM_ADDR_FLAG_SIGNAL; } +/* caller must ensure the RCU grace period is already elapsed */ +static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry) +{ + if (entry->lsk) + sock_release(entry->lsk); + kfree(entry); +} + static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, struct mptcp_pm_addr_entry *entry) { - struct mptcp_pm_addr_entry *cur; + struct mptcp_pm_addr_entry *cur, *del_entry = NULL; unsigned int addr_max; int ret = -EINVAL; @@ -901,8 +891,22 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, list_for_each_entry(cur, &pernet->local_addr_list, list) { if (addresses_equal(&cur->addr, &entry->addr, address_use_port(entry) && - address_use_port(cur))) - goto out; + address_use_port(cur))) { + /* allow replacing the exiting endpoint only if such + * endpoint is an implicit one and the user-space + * did not provide an endpoint id + */ + if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)) + goto out; + if (entry->addr.id) + goto out; + + pernet->addrs--; + entry->addr.id = cur->addr.id; + list_del_rcu(&cur->list); + del_entry = cur; + break; + } } if (!entry->addr.id) { @@ -938,6 +942,12 @@ find_next: out: spin_unlock_bh(&pernet->lock); + + /* just replaced an existing entry, free it */ + if (del_entry) { + synchronize_rcu(); + __mptcp_pm_release_addr_entry(del_entry); + } return ret; } @@ -1011,9 +1021,6 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) if (addresses_equal(&msk_local, &skc_local, false)) return 0; - if (address_zero(&skc_local)) - return 0; - pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); rcu_read_lock(); @@ -1036,7 +1043,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) entry->addr.id = 0; entry->addr.port = 0; entry->ifindex = 0; - entry->flags = 0; + entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; entry->lsk = NULL; ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); if (ret < 0) @@ -1197,14 +1204,8 @@ skip_family: if (tb[MPTCP_PM_ADDR_ATTR_FLAGS]) entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]); - if (tb[MPTCP_PM_ADDR_ATTR_PORT]) { - if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { - NL_SET_ERR_MSG_ATTR(info->extack, attr, - "flags must have signal when using port"); - return -EINVAL; - } + if (tb[MPTCP_PM_ADDR_ATTR_PORT]) entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); - } return 0; } @@ -1250,6 +1251,22 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) if (ret < 0) return ret; + if (addr.addr.port && !(addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + GENL_SET_ERR_MSG(info, "flags must have signal when using port"); + return -EINVAL; + } + + if (addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL && + addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) { + GENL_SET_ERR_MSG(info, "flags mustn't have both signal and fullmesh"); + return -EINVAL; + } + + if (addr.flags & MPTCP_PM_ADDR_FLAG_IMPLICIT) { + GENL_SET_ERR_MSG(info, "can't create IMPLICIT endpoint"); + return -EINVAL; + } + entry = kmalloc(sizeof(*entry), GFP_KERNEL); if (!entry) { GENL_SET_ERR_MSG(info, "can't allocate addr"); @@ -1301,7 +1318,7 @@ int mptcp_pm_get_flags_and_ifindex_by_id(struct net *net, unsigned int id, } static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, - struct mptcp_addr_info *addr) + const struct mptcp_addr_info *addr) { struct mptcp_pm_add_entry *entry; @@ -1316,7 +1333,7 @@ static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, } static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, - struct mptcp_addr_info *addr, + const struct mptcp_addr_info *addr, bool force) { struct mptcp_rm_list list = { .nr = 0 }; @@ -1334,11 +1351,12 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, } static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, - struct mptcp_addr_info *addr) + const struct mptcp_pm_addr_entry *entry) { - struct mptcp_sock *msk; - long s_slot = 0, s_num = 0; + const struct mptcp_addr_info *addr = &entry->addr; struct mptcp_rm_list list = { .nr = 0 }; + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; pr_debug("remove_id=%d", addr->id); @@ -1355,7 +1373,8 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, lock_sock(sk); remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); - mptcp_pm_remove_anno_addr(msk, addr, remove_subflow); + mptcp_pm_remove_anno_addr(msk, addr, remove_subflow && + !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)); if (remove_subflow) mptcp_pm_remove_subflow(msk, &list); release_sock(sk); @@ -1368,14 +1387,6 @@ next: return 0; } -/* caller must ensure the RCU grace period is already elapsed */ -static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry) -{ - if (entry->lsk) - sock_release(entry->lsk); - kfree(entry); -} - static int mptcp_nl_remove_id_zero_address(struct net *net, struct mptcp_addr_info *addr) { @@ -1452,7 +1463,7 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) __clear_bit(entry->addr.id, pernet->id_bitmap); spin_unlock_bh(&pernet->lock); - mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); + mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry); synchronize_rcu(); __mptcp_pm_release_addr_entry(entry); @@ -1467,14 +1478,12 @@ static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, list_for_each_entry(entry, rm_list, list) { if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) && - alist.nr < MPTCP_RM_IDS_MAX && - slist.nr < MPTCP_RM_IDS_MAX) { - alist.ids[alist.nr++] = entry->addr.id; + slist.nr < MPTCP_RM_IDS_MAX) slist.ids[slist.nr++] = entry->addr.id; - } else if (remove_anno_list_by_saddr(msk, &entry->addr) && - alist.nr < MPTCP_RM_IDS_MAX) { + + if (remove_anno_list_by_saddr(msk, &entry->addr) && + alist.nr < MPTCP_RM_IDS_MAX) alist.ids[alist.nr++] = entry->addr.id; - } } if (alist.nr) { @@ -1751,9 +1760,20 @@ fail: return -EMSGSIZE; } -static int mptcp_nl_addr_backup(struct net *net, - struct mptcp_addr_info *addr, - u8 bkup) +static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk, + struct mptcp_addr_info *addr) +{ + struct mptcp_rm_list list = { .nr = 0 }; + + list.ids[list.nr++] = addr->id; + + mptcp_pm_nl_rm_subflow_received(msk, &list); + mptcp_pm_create_subflow_or_signal_addr(msk); +} + +static int mptcp_nl_set_flags(struct net *net, + struct mptcp_addr_info *addr, + u8 bkup, u8 changed) { long s_slot = 0, s_num = 0; struct mptcp_sock *msk; @@ -1767,7 +1787,10 @@ static int mptcp_nl_addr_backup(struct net *net, lock_sock(sk); spin_lock_bh(&msk->pm.lock); - ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup); + if (changed & MPTCP_PM_ADDR_FLAG_BACKUP) + ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup); + if (changed & MPTCP_PM_ADDR_FLAG_FULLMESH) + mptcp_pm_nl_fullmesh(msk, addr); spin_unlock_bh(&msk->pm.lock); release_sock(sk); @@ -1784,6 +1807,8 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry; struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP | + MPTCP_PM_ADDR_FLAG_FULLMESH; struct net *net = sock_net(skb->sk); u8 bkup = 0, lookup_by_id = 0; int ret; @@ -1806,15 +1831,18 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) spin_unlock_bh(&pernet->lock); return -EINVAL; } + if ((addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) && + (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + spin_unlock_bh(&pernet->lock); + return -EINVAL; + } - if (bkup) - entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP; - else - entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; + changed = (addr.flags ^ entry->flags) & mask; + entry->flags = (entry->flags & ~mask) | (addr.flags & mask); addr = *entry; spin_unlock_bh(&pernet->lock); - mptcp_nl_addr_backup(net, &addr.addr, bkup); + mptcp_nl_set_flags(net, &addr.addr, bkup, changed); return 0; } |