selftests/bpf: Introduce __attribute__((cleanup)) in create_pair()

Rewrite function to have (unneeded) socket descriptors automatically
close()d when leaving the scope. Make sure the "ownership" of fds is
correctly passed via take_fd(); i.e. descriptor returned to caller will
remain valid.

Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Tested-by: Jakub Sitnicki <jakub@cloudflare.com>
Suggested-by: Jakub Sitnicki <jakub@cloudflare.com>
Signed-off-by: Michal Luczaj <mhal@rbox.co>
Link: https://lore.kernel.org/r/20240731-selftest-sockmap-fixes-v2-6-08a0c73abed2@rbox.co
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
This commit is contained in:
Michal Luczaj 2024-07-31 12:01:31 +02:00 committed by Martin KaFai Lau
parent c9c70b28fa
commit 86149b4f5a

View File

@ -17,6 +17,17 @@
#define __always_unused __attribute__((__unused__)) #define __always_unused __attribute__((__unused__))
/* include/linux/cleanup.h */
#define __get_and_null(p, nullvalue) \
({ \
__auto_type __ptr = &(p); \
__auto_type __val = *__ptr; \
*__ptr = nullvalue; \
__val; \
})
#define take_fd(fd) __get_and_null(fd, -EBADF)
#define _FAIL(errnum, fmt...) \ #define _FAIL(errnum, fmt...) \
({ \ ({ \
error_at_line(0, (errnum), __func__, __LINE__, fmt); \ error_at_line(0, (errnum), __func__, __LINE__, fmt); \
@ -182,6 +193,14 @@
__ret; \ __ret; \
}) })
static inline void close_fd(int *fd)
{
if (*fd >= 0)
xclose(*fd);
}
#define __close_fd __attribute__((cleanup(close_fd)))
static inline int poll_connect(int fd, unsigned int timeout_sec) static inline int poll_connect(int fd, unsigned int timeout_sec)
{ {
struct timeval timeout = { .tv_sec = timeout_sec }; struct timeval timeout = { .tv_sec = timeout_sec };
@ -369,9 +388,10 @@ static inline int socket_loopback(int family, int sotype)
static inline int create_pair(int family, int sotype, int *p0, int *p1) static inline int create_pair(int family, int sotype, int *p0, int *p1)
{ {
__close_fd int s, c = -1, p = -1;
struct sockaddr_storage addr; struct sockaddr_storage addr;
socklen_t len = sizeof(addr); socklen_t len = sizeof(addr);
int s, c, p, err; int err;
s = socket_loopback(family, sotype); s = socket_loopback(family, sotype);
if (s < 0) if (s < 0)
@ -379,25 +399,23 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
err = xgetsockname(s, sockaddr(&addr), &len); err = xgetsockname(s, sockaddr(&addr), &len);
if (err) if (err)
goto close_s; return err;
c = xsocket(family, sotype, 0); c = xsocket(family, sotype, 0);
if (c < 0) { if (c < 0)
err = c; return c;
goto close_s;
}
err = connect(c, sockaddr(&addr), len); err = connect(c, sockaddr(&addr), len);
if (err) { if (err) {
if (errno != EINPROGRESS) { if (errno != EINPROGRESS) {
FAIL_ERRNO("connect"); FAIL_ERRNO("connect");
goto close_c; return err;
} }
err = poll_connect(c, IO_TIMEOUT_SEC); err = poll_connect(c, IO_TIMEOUT_SEC);
if (err) { if (err) {
FAIL_ERRNO("poll_connect"); FAIL_ERRNO("poll_connect");
goto close_c; return err;
} }
} }
@ -405,36 +423,29 @@ static inline int create_pair(int family, int sotype, int *p0, int *p1)
case SOCK_DGRAM: case SOCK_DGRAM:
err = xgetsockname(c, sockaddr(&addr), &len); err = xgetsockname(c, sockaddr(&addr), &len);
if (err) if (err)
goto close_c; return err;
err = xconnect(s, sockaddr(&addr), len); err = xconnect(s, sockaddr(&addr), len);
if (!err) { if (err)
*p0 = s;
*p1 = c;
return err; return err;
}
*p0 = take_fd(s);
break; break;
case SOCK_STREAM: case SOCK_STREAM:
case SOCK_SEQPACKET: case SOCK_SEQPACKET:
p = xaccept_nonblock(s, NULL, NULL); p = xaccept_nonblock(s, NULL, NULL);
if (p >= 0) { if (p < 0)
*p0 = p; return p;
*p1 = c;
goto close_s;
}
err = p; *p0 = take_fd(p);
break; break;
default: default:
FAIL("Unsupported socket type %#x", sotype); FAIL("Unsupported socket type %#x", sotype);
err = -EOPNOTSUPP; return -EOPNOTSUPP;
} }
close_c: *p1 = take_fd(c);
close(c); return 0;
close_s:
close(s);
return err;
} }
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1, static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,