mirror of
https://mirrors.bfsu.edu.cn/git/linux.git
synced 2024-11-11 12:28:41 +08:00
Merge branch 'virtio-vsock-some-updates-for-msg_peek-flag'
Arseniy Krasnov says: ==================== virtio/vsock: some updates for MSG_PEEK flag This patchset does several things around MSG_PEEK flag support. In general words it reworks MSG_PEEK test and adds support for this flag in SOCK_SEQPACKET logic. Here is per-patch description: 1) This is cosmetic change for SOCK_STREAM implementation of MSG_PEEK: 1) I think there is no need of "safe" mode walk here as there is no "unlink" of skbs inside loop (it is MSG_PEEK mode - we don't change queue). 2) Nested while loop is removed: in case of MSG_PEEK we just walk over skbs and copy data from each one. I guess this nested loop even didn't behave as loop - it always executed just for single iteration. 2) This adds MSG_PEEK support for SOCK_SEQPACKET. It could be implemented be reworking MSG_PEEK callback for SOCK_STREAM to support SOCK_SEQPACKET also, but I think it will be more simple and clear from potential bugs to implemented it as separate function thus not mixing logics for both types of socket. So I've added it as dedicated function. 3) This is reworked MSG_PEEK test for SOCK_STREAM. Previous version just sent single byte, then tried to read it with MSG_PEEK flag, then read it in normal way. New version is more complex: now sender uses buffer instead of single byte and this buffer is initialized with random values. Receiver tests several things: 1) Read empty socket with MSG_PEEK flag. 2) Read part of buffer with MSG_PEEK flag. 3) Read whole buffer with MSG_PEEK flag, then checks that it is same as buffer from 2) (limited by size of buffer from 2) of course). 4) Read whole buffer without any flags, then checks that it is same as buffer from 3). 4) This is MSG_PEEK test for SOCK_SEQPACKET. It works in the same way as for SOCK_STREAM, except it also checks combination of MSG_TRUNC and MSG_PEEK. ==================== Link: https://lore.kernel.org/r/20230725172912.1659970-1-AVKrasnov@sberdevices.ru Signed-off-by: Paolo Abeni <pabeni@redhat.com>
This commit is contained in:
commit
9d0cd5d25f
@ -348,37 +348,34 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk,
|
||||
size_t len)
|
||||
{
|
||||
struct virtio_vsock_sock *vvs = vsk->trans;
|
||||
size_t bytes, total = 0, off;
|
||||
struct sk_buff *skb, *tmp;
|
||||
int err = -EFAULT;
|
||||
struct sk_buff *skb;
|
||||
size_t total = 0;
|
||||
int err;
|
||||
|
||||
spin_lock_bh(&vvs->rx_lock);
|
||||
|
||||
skb_queue_walk_safe(&vvs->rx_queue, skb, tmp) {
|
||||
off = 0;
|
||||
skb_queue_walk(&vvs->rx_queue, skb) {
|
||||
size_t bytes;
|
||||
|
||||
bytes = len - total;
|
||||
if (bytes > skb->len)
|
||||
bytes = skb->len;
|
||||
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
|
||||
/* sk_lock is held by caller so no one else can dequeue.
|
||||
* Unlock rx_lock since memcpy_to_msg() may sleep.
|
||||
*/
|
||||
err = memcpy_to_msg(msg, skb->data, bytes);
|
||||
if (err)
|
||||
goto out;
|
||||
|
||||
total += bytes;
|
||||
|
||||
spin_lock_bh(&vvs->rx_lock);
|
||||
|
||||
if (total == len)
|
||||
break;
|
||||
|
||||
while (total < len && off < skb->len) {
|
||||
bytes = len - total;
|
||||
if (bytes > skb->len - off)
|
||||
bytes = skb->len - off;
|
||||
|
||||
/* sk_lock is held by caller so no one else can dequeue.
|
||||
* Unlock rx_lock since memcpy_to_msg() may sleep.
|
||||
*/
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
|
||||
err = memcpy_to_msg(msg, skb->data + off, bytes);
|
||||
if (err)
|
||||
goto out;
|
||||
|
||||
spin_lock_bh(&vvs->rx_lock);
|
||||
|
||||
total += bytes;
|
||||
off += bytes;
|
||||
}
|
||||
}
|
||||
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
@ -463,6 +460,63 @@ out:
|
||||
return err;
|
||||
}
|
||||
|
||||
static ssize_t
|
||||
virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
|
||||
struct msghdr *msg)
|
||||
{
|
||||
struct virtio_vsock_sock *vvs = vsk->trans;
|
||||
struct sk_buff *skb;
|
||||
size_t total, len;
|
||||
|
||||
spin_lock_bh(&vvs->rx_lock);
|
||||
|
||||
if (!vvs->msg_count) {
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
return 0;
|
||||
}
|
||||
|
||||
total = 0;
|
||||
len = msg_data_left(msg);
|
||||
|
||||
skb_queue_walk(&vvs->rx_queue, skb) {
|
||||
struct virtio_vsock_hdr *hdr;
|
||||
|
||||
if (total < len) {
|
||||
size_t bytes;
|
||||
int err;
|
||||
|
||||
bytes = len - total;
|
||||
if (bytes > skb->len)
|
||||
bytes = skb->len;
|
||||
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
|
||||
/* sk_lock is held by caller so no one else can dequeue.
|
||||
* Unlock rx_lock since memcpy_to_msg() may sleep.
|
||||
*/
|
||||
err = memcpy_to_msg(msg, skb->data, bytes);
|
||||
if (err)
|
||||
return err;
|
||||
|
||||
spin_lock_bh(&vvs->rx_lock);
|
||||
}
|
||||
|
||||
total += skb->len;
|
||||
hdr = virtio_vsock_hdr(skb);
|
||||
|
||||
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
|
||||
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
|
||||
msg->msg_flags |= MSG_EOR;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
spin_unlock_bh(&vvs->rx_lock);
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
|
||||
struct msghdr *msg,
|
||||
int flags)
|
||||
@ -557,9 +611,9 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
|
||||
int flags)
|
||||
{
|
||||
if (flags & MSG_PEEK)
|
||||
return -EOPNOTSUPP;
|
||||
|
||||
return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
|
||||
return virtio_transport_seqpacket_do_peek(vsk, msg);
|
||||
else
|
||||
return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
|
||||
}
|
||||
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
|
||||
|
||||
|
@ -255,35 +255,142 @@ static void test_stream_multiconn_server(const struct test_opts *opts)
|
||||
close(fds[i]);
|
||||
}
|
||||
|
||||
static void test_stream_msg_peek_client(const struct test_opts *opts)
|
||||
{
|
||||
int fd;
|
||||
#define MSG_PEEK_BUF_LEN 64
|
||||
|
||||
static void test_msg_peek_client(const struct test_opts *opts,
|
||||
bool seqpacket)
|
||||
{
|
||||
unsigned char buf[MSG_PEEK_BUF_LEN];
|
||||
ssize_t send_size;
|
||||
int fd;
|
||||
int i;
|
||||
|
||||
if (seqpacket)
|
||||
fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
|
||||
else
|
||||
fd = vsock_stream_connect(opts->peer_cid, 1234);
|
||||
|
||||
fd = vsock_stream_connect(opts->peer_cid, 1234);
|
||||
if (fd < 0) {
|
||||
perror("connect");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
send_byte(fd, 1, 0);
|
||||
for (i = 0; i < sizeof(buf); i++)
|
||||
buf[i] = rand() & 0xFF;
|
||||
|
||||
control_expectln("SRVREADY");
|
||||
|
||||
send_size = send(fd, buf, sizeof(buf), 0);
|
||||
|
||||
if (send_size < 0) {
|
||||
perror("send");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (send_size != sizeof(buf)) {
|
||||
fprintf(stderr, "Invalid send size %zi\n", send_size);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
close(fd);
|
||||
}
|
||||
|
||||
static void test_stream_msg_peek_server(const struct test_opts *opts)
|
||||
static void test_msg_peek_server(const struct test_opts *opts,
|
||||
bool seqpacket)
|
||||
{
|
||||
unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
|
||||
unsigned char buf_normal[MSG_PEEK_BUF_LEN];
|
||||
unsigned char buf_peek[MSG_PEEK_BUF_LEN];
|
||||
ssize_t res;
|
||||
int fd;
|
||||
|
||||
fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
|
||||
if (seqpacket)
|
||||
fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
|
||||
else
|
||||
fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
|
||||
|
||||
if (fd < 0) {
|
||||
perror("accept");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
recv_byte(fd, 1, MSG_PEEK);
|
||||
recv_byte(fd, 1, 0);
|
||||
/* Peek from empty socket. */
|
||||
res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT);
|
||||
if (res != -1) {
|
||||
fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (errno != EAGAIN) {
|
||||
perror("EAGAIN expected");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
control_writeln("SRVREADY");
|
||||
|
||||
/* Peek part of data. */
|
||||
res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK);
|
||||
if (res != sizeof(buf_half)) {
|
||||
fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
|
||||
sizeof(buf_half), res);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* Peek whole data. */
|
||||
res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK);
|
||||
if (res != sizeof(buf_peek)) {
|
||||
fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
|
||||
sizeof(buf_peek), res);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* Compare partial and full peek. */
|
||||
if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
|
||||
fprintf(stderr, "Partial peek data mismatch\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
if (seqpacket) {
|
||||
/* This type of socket supports MSG_TRUNC flag,
|
||||
* so check it with MSG_PEEK. We must get length
|
||||
* of the message.
|
||||
*/
|
||||
res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK |
|
||||
MSG_TRUNC);
|
||||
if (res != sizeof(buf_peek)) {
|
||||
fprintf(stderr,
|
||||
"recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n",
|
||||
sizeof(buf_half), res);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
res = recv(fd, buf_normal, sizeof(buf_normal), 0);
|
||||
if (res != sizeof(buf_normal)) {
|
||||
fprintf(stderr, "recv(2), expected %zu, got %zi\n",
|
||||
sizeof(buf_normal), res);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* Compare full peek and normal read. */
|
||||
if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
|
||||
fprintf(stderr, "Full peek data mismatch\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
close(fd);
|
||||
}
|
||||
|
||||
static void test_stream_msg_peek_client(const struct test_opts *opts)
|
||||
{
|
||||
return test_msg_peek_client(opts, false);
|
||||
}
|
||||
|
||||
static void test_stream_msg_peek_server(const struct test_opts *opts)
|
||||
{
|
||||
return test_msg_peek_server(opts, false);
|
||||
}
|
||||
|
||||
#define SOCK_BUF_SIZE (2 * 1024 * 1024)
|
||||
#define MAX_MSG_SIZE (32 * 1024)
|
||||
|
||||
@ -1053,6 +1160,16 @@ static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
|
||||
close(fd);
|
||||
}
|
||||
|
||||
static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
|
||||
{
|
||||
return test_msg_peek_client(opts, true);
|
||||
}
|
||||
|
||||
static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
|
||||
{
|
||||
return test_msg_peek_server(opts, true);
|
||||
}
|
||||
|
||||
static struct test_case test_cases[] = {
|
||||
{
|
||||
.name = "SOCK_STREAM connection reset",
|
||||
@ -1128,6 +1245,11 @@ static struct test_case test_cases[] = {
|
||||
.run_client = test_stream_virtio_skb_merge_client,
|
||||
.run_server = test_stream_virtio_skb_merge_server,
|
||||
},
|
||||
{
|
||||
.name = "SOCK_SEQPACKET MSG_PEEK",
|
||||
.run_client = test_seqpacket_msg_peek_client,
|
||||
.run_server = test_seqpacket_msg_peek_server,
|
||||
},
|
||||
{},
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user