net: mctp: separate key correlation across nets

Currently, we lookup sk_keys from the entire struct net_namespace, which
may contain multiple MCTP net IDs. In those cases we want to distinguish
between endpoints with the same EID but different net ID.

Add the net ID data to the struct mctp_sk_key, populate on add and
filter on this during route lookup.

For the ioctl interface, we use a default net of
MCTP_INITIAL_DEFAULT_NET (ie., what will be in use for single-net
configurations), but we'll extend the ioctl interface to provide
net-specific tag allocation in an upcoming change.

Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
This commit is contained in:
Jeremy Kerr 2024-02-19 17:51:50 +08:00 committed by Paolo Abeni
parent a1f4cf5791
commit 43e6795574
4 changed files with 40 additions and 16 deletions

View File

@ -133,6 +133,7 @@ struct mctp_sock {
* - through an expiry timeout, on a per-socket timer
*/
struct mctp_sk_key {
unsigned int net;
mctp_eid_t peer_addr;
mctp_eid_t local_addr; /* MCTP_ADDR_ANY for local owned tags */
__u8 tag; /* incoming tag match; invert TO for local */
@ -254,6 +255,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
void mctp_key_unref(struct mctp_sk_key *key);
struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
unsigned int netid,
mctp_eid_t local, mctp_eid_t peer,
bool manual, u8 *tagp);

View File

@ -367,8 +367,8 @@ static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
if (ctl.flags)
return -EINVAL;
key = mctp_alloc_local_tag(msk, MCTP_ADDR_ANY, ctl.peer_addr,
true, &tag);
key = mctp_alloc_local_tag(msk, MCTP_INITIAL_DEFAULT_NET,
MCTP_ADDR_ANY, ctl.peer_addr, true, &tag);
if (IS_ERR(key))
return PTR_ERR(key);

View File

@ -107,9 +107,12 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
* and peer addresses, or either being ANY.
*/
static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
mctp_eid_t peer, u8 tag)
static bool mctp_key_match(struct mctp_sk_key *key, unsigned int net,
mctp_eid_t local, mctp_eid_t peer, u8 tag)
{
if (key->net != net)
return false;
if (!mctp_address_matches(key->local_addr, local))
return false;
@ -126,7 +129,7 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
* key exists.
*/
static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
mctp_eid_t peer,
unsigned int netid, mctp_eid_t peer,
unsigned long *irqflags)
__acquires(&key->lock)
{
@ -142,7 +145,7 @@ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry(key, &net->mctp.keys, hlist) {
if (!mctp_key_match(key, mh->dest, peer, tag))
if (!mctp_key_match(key, netid, mh->dest, peer, tag))
continue;
spin_lock(&key->lock);
@ -165,6 +168,7 @@ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
}
static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
unsigned int net,
mctp_eid_t local, mctp_eid_t peer,
u8 tag, gfp_t gfp)
{
@ -174,6 +178,7 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
if (!key)
return NULL;
key->net = net;
key->peer_addr = peer;
key->local_addr = local;
key->tag = tag;
@ -219,8 +224,8 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
}
hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
key->tag)) {
if (mctp_key_match(tmp, key->net, key->local_addr,
key->peer_addr, key->tag)) {
spin_lock(&tmp->lock);
if (tmp->valid)
rc = -EEXIST;
@ -361,6 +366,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
struct net *net = dev_net(skb->dev);
struct mctp_sock *msk;
struct mctp_hdr *mh;
unsigned int netid;
unsigned long f;
u8 tag, flags;
int rc;
@ -379,6 +385,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
/* grab header, advance data ptr */
mh = mctp_hdr(skb);
netid = mctp_cb(skb)->net;
skb_pull(skb, sizeof(struct mctp_hdr));
if (mh->ver != 1)
@ -392,7 +399,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
/* lookup socket / reasm context, exactly matching (src,dest,tag).
* we hold a ref on the key, and key->lock held.
*/
key = mctp_lookup_key(net, skb, mh->src, &f);
key = mctp_lookup_key(net, skb, netid, mh->src, &f);
if (flags & MCTP_HDR_FLAG_SOM) {
if (key) {
@ -406,7 +413,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* this lookup requires key->peer to be MCTP_ADDR_ANY,
* it doesn't match just any key->peer.
*/
any_key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
any_key = mctp_lookup_key(net, skb, netid,
MCTP_ADDR_ANY, &f);
if (any_key) {
msk = container_of(any_key->sk,
struct mctp_sock, sk);
@ -443,7 +451,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* packets for this message
*/
if (!key) {
key = mctp_key_alloc(msk, mh->dest, mh->src,
key = mctp_key_alloc(msk, netid, mh->dest, mh->src,
tag, GFP_ATOMIC);
if (!key) {
rc = -ENOMEM;
@ -637,6 +645,7 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
* it for the socket msk
*/
struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
unsigned int netid,
mctp_eid_t local, mctp_eid_t peer,
bool manual, u8 *tagp)
{
@ -651,7 +660,7 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
peer = MCTP_ADDR_ANY;
/* be optimistic, alloc now */
key = mctp_key_alloc(msk, local, peer, 0, GFP_KERNEL);
key = mctp_key_alloc(msk, netid, local, peer, 0, GFP_KERNEL);
if (!key)
return ERR_PTR(-ENOMEM);
@ -668,6 +677,10 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
* lock held, they don't change over the lifetime of the key.
*/
/* tags are net-specific */
if (tmp->net != netid)
continue;
/* if we don't own the tag, it can't conflict */
if (tmp->tag & MCTP_HDR_FLAG_TO)
continue;
@ -716,6 +729,7 @@ struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
}
static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
unsigned int netid,
mctp_eid_t daddr,
u8 req_tag, u8 *tagp)
{
@ -730,6 +744,9 @@ static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
spin_lock_irqsave(&mns->keys_lock, flags);
hlist_for_each_entry(tmp, &mns->keys, hlist) {
if (tmp->net != netid)
continue;
if (tmp->tag != req_tag)
continue;
@ -910,6 +927,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
struct mctp_sk_key *key;
struct mctp_hdr *hdr;
unsigned long flags;
unsigned int netid;
unsigned int mtu;
mctp_eid_t saddr;
bool ext_rt;
@ -960,16 +978,17 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
rc = 0;
}
spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);
netid = READ_ONCE(rt->dev->net);
if (rc)
goto out_release;
if (req_tag & MCTP_TAG_OWNER) {
if (req_tag & MCTP_TAG_PREALLOC)
key = mctp_lookup_prealloc_tag(msk, daddr,
key = mctp_lookup_prealloc_tag(msk, netid, daddr,
req_tag, &tag);
else
key = mctp_alloc_local_tag(msk, saddr, daddr,
key = mctp_alloc_local_tag(msk, netid, saddr, daddr,
false, &tag);
if (IS_ERR(key)) {

View File

@ -552,6 +552,7 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
struct mctp_sock *msk;
struct socket *sock;
unsigned long flags;
unsigned int net;
int rc;
u8 c;
@ -559,6 +560,7 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
dev = mctp_test_create_dev();
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
net = READ_ONCE(dev->mdev->net);
rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
@ -570,8 +572,9 @@ static void mctp_test_route_input_sk_keys(struct kunit *test)
mns = &sock_net(sock->sk)->mctp;
/* set the incoming tag according to test params */
key = mctp_key_alloc(msk, params->key_local_addr, params->key_peer_addr,
params->key_tag, GFP_KERNEL);
key = mctp_key_alloc(msk, net, params->key_local_addr,
params->key_peer_addr, params->key_tag,
GFP_KERNEL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);