diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 2ab45e3f48bf..7d519a46a844 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -820,13 +820,8 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) /* Get L2TP header flags */ hdrflags = ntohs(*(__be16 *)ptr); - /* Check protocol version */ + /* Get protocol version */ version = hdrflags & L2TP_HDR_VER_MASK; - if (version != tunnel->version) { - pr_debug_ratelimited("%s: recv protocol version mismatch: got %d expected %d\n", - tunnel->name, version, tunnel->version); - goto invalid; - } /* Get length of L2TP packet */ length = skb->len; @@ -838,7 +833,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) /* Skip flags */ ptr += 2; - if (tunnel->version == L2TP_HDR_VER_2) { + if (version == L2TP_HDR_VER_2) { /* If length is present, skip it */ if (hdrflags & L2TP_HDRFLAG_L) ptr += 2; @@ -855,7 +850,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) struct l2tp_tunnel *alt_tunnel; alt_tunnel = l2tp_tunnel_get(tunnel->l2tp_net, tunnel_id); - if (!alt_tunnel || alt_tunnel->version != L2TP_HDR_VER_2) + if (!alt_tunnel) goto pass; tunnel = alt_tunnel; } @@ -869,6 +864,13 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) ptr += 4; } + /* Check protocol version */ + if (version != tunnel->version) { + pr_debug_ratelimited("%s: recv protocol version mismatch: got %d expected %d\n", + tunnel->name, version, tunnel->version); + goto invalid; + } + /* Find the session context */ session = l2tp_tunnel_get_session(tunnel, session_id); if (!session || !session->recv_skb) {