diff --git a/include/net/udp.h b/include/net/udp.h index 9e82cb391dea..a496e441645e 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -252,6 +252,17 @@ static inline int udp_rqueue_get(struct sock *sk) return sk_rmem_alloc_get(sk) - READ_ONCE(udp_sk(sk)->forward_deficit); } +static inline bool udp_sk_bound_dev_eq(struct net *net, int bound_dev_if, + int dif, int sdif) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + return inet_bound_dev_eq(!!net->ipv4.sysctl_udp_l3mdev_accept, + bound_dev_if, dif, sdif); +#else + return inet_bound_dev_eq(true, bound_dev_if, dif, sdif); +#endif +} + /* net/ipv4/udp.c */ void udp_destruct_sock(struct sock *sk); void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len); diff --git a/net/core/sock.c b/net/core/sock.c index 080a880a1761..7b304e454a38 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -567,6 +567,8 @@ static int sock_setbindtodevice(struct sock *sk, char __user *optval, lock_sock(sk); sk->sk_bound_dev_if = index; + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); sk_dst_reset(sk); release_sock(sk); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 1976fddb9e00..cf73c9194bb6 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -371,6 +371,7 @@ static int compute_score(struct sock *sk, struct net *net, { int score; struct inet_sock *inet; + bool dev_match; if (!net_eq(sock_net(sk), net) || udp_sk(sk)->udp_port_hash != hnum || @@ -398,15 +399,11 @@ static int compute_score(struct sock *sk, struct net *net, score += 4; } - if (sk->sk_bound_dev_if || exact_dif) { - bool dev_match = (sk->sk_bound_dev_if == dif || - sk->sk_bound_dev_if == sdif); - - if (!dev_match) - return -1; - if (sk->sk_bound_dev_if) - score += 4; - } + dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, + dif, sdif); + if (!dev_match) + return -1; + score += 4; if (sk->sk_incoming_cpu == raw_smp_processor_id()) score++; diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c index 1ede7a16a0be..bde08aa549f3 100644 --- a/net/ipv6/datagram.c +++ b/net/ipv6/datagram.c @@ -772,6 +772,7 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk, case IPV6_2292PKTINFO: { struct net_device *dev = NULL; + int src_idx; if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct in6_pktinfo))) { err = -EINVAL; @@ -779,12 +780,15 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk, } src_info = (struct in6_pktinfo *)CMSG_DATA(cmsg); + src_idx = src_info->ipi6_ifindex; - if (src_info->ipi6_ifindex) { + if (src_idx) { if (fl6->flowi6_oif && - src_info->ipi6_ifindex != fl6->flowi6_oif) + src_idx != fl6->flowi6_oif && + (sk->sk_bound_dev_if != fl6->flowi6_oif || + !sk_dev_equal_l3scope(sk, src_idx))) return -EINVAL; - fl6->flowi6_oif = src_info->ipi6_ifindex; + fl6->flowi6_oif = src_idx; } addr_type = __ipv6_addr_type(&src_info->ipi6_addr); diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index d2d97d07ef27..0559adc2f357 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -117,6 +117,7 @@ static int compute_score(struct sock *sk, struct net *net, { int score; struct inet_sock *inet; + bool dev_match; if (!net_eq(sock_net(sk), net) || udp_sk(sk)->udp_port_hash != hnum || @@ -144,15 +145,10 @@ static int compute_score(struct sock *sk, struct net *net, score++; } - if (sk->sk_bound_dev_if || exact_dif) { - bool dev_match = (sk->sk_bound_dev_if == dif || - sk->sk_bound_dev_if == sdif); - - if (!dev_match) - return -1; - if (sk->sk_bound_dev_if) - score++; - } + dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif); + if (!dev_match) + return -1; + score++; if (sk->sk_incoming_cpu == raw_smp_processor_id()) score++;