diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 6cb077b646a5..ef7031f8a304 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -14,6 +14,7 @@ #include #define MAX_MSG_FRAGS MAX_SKB_FRAGS +#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1) enum __sk_action { __SK_DROP = 0, @@ -29,13 +30,15 @@ struct sk_msg_sg { u32 size; u32 copybreak; unsigned long copy; - /* The extra element is used for chaining the front and sections when - * the list becomes partitioned (e.g. end < start). The crypto APIs - * require the chaining. + /* The extra two elements: + * 1) used for chaining the front and sections when the list becomes + * partitioned (e.g. end < start). The crypto APIs require the + * chaining; + * 2) to chain tailer SG entries after the message. */ - struct scatterlist data[MAX_MSG_FRAGS + 1]; + struct scatterlist data[MAX_MSG_FRAGS + 2]; }; -static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS); +static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS); /* UAPI in filter.c depends on struct sk_msg_sg being first element. */ struct sk_msg { @@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes) static inline u32 sk_msg_iter_dist(u32 start, u32 end) { - return end >= start ? end - start : end + (MAX_MSG_FRAGS - start); + return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start); } #define sk_msg_iter_var_prev(var) \ do { \ if (var == 0) \ - var = MAX_MSG_FRAGS - 1; \ + var = NR_MSG_FRAG_IDS - 1; \ else \ var--; \ } while (0) @@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end) #define sk_msg_iter_var_next(var) \ do { \ var++; \ - if (var == MAX_MSG_FRAGS) \ + if (var == NR_MSG_FRAG_IDS) \ var = 0; \ } while (0) @@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg) static inline void sk_msg_init(struct sk_msg *msg) { - BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS); + BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS); memset(msg, 0, sizeof(*msg)); - sg_init_marker(msg->sg.data, MAX_MSG_FRAGS); + sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS); } static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, @@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src) static inline bool sk_msg_full(const struct sk_msg *msg) { - return (msg->sg.end == msg->sg.start) && msg->sg.size; + return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS; } static inline u32 sk_msg_elem_used(const struct sk_msg *msg) { - if (sk_msg_full(msg)) - return MAX_MSG_FRAGS; - return sk_msg_iter_dist(msg->sg.start, msg->sg.end); } diff --git a/include/net/tls.h b/include/net/tls.h index 6ed91e82edd0..df630f5fc723 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -100,7 +100,6 @@ struct tls_rec { struct list_head list; int tx_ready; int tx_flags; - int inplace_crypto; struct sk_msg msg_plaintext; struct sk_msg msg_encrypted; @@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx, int flags); int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, int flags); -bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx); +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx); static inline struct tls_msg *tls_msg(struct sk_buff *skb) { diff --git a/net/core/filter.c b/net/core/filter.c index b0ed048585ba..f1e703eed3d2 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, WARN_ON_ONCE(last_sge == first_sge); shift = last_sge > first_sge ? last_sge - first_sge - 1 : - MAX_SKB_FRAGS - first_sge + last_sge - 1; + NR_MSG_FRAG_IDS - first_sge + last_sge - 1; if (!shift) goto out; @@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, do { u32 move_from; - if (i + shift >= MAX_MSG_FRAGS) - move_from = i + shift - MAX_MSG_FRAGS; + if (i + shift >= NR_MSG_FRAG_IDS) + move_from = i + shift - NR_MSG_FRAG_IDS; else move_from = i + shift; if (move_from == msg->sg.end) @@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, } while (1); msg->sg.end = msg->sg.end - shift > msg->sg.end ? - msg->sg.end - shift + MAX_MSG_FRAGS : + msg->sg.end - shift + NR_MSG_FRAG_IDS : msg->sg.end - shift; out: msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset; diff --git a/net/core/skmsg.c b/net/core/skmsg.c index a469d2124f3f..ded2d5227678 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) copied = skb->len; msg->sg.start = 0; msg->sg.size = copied; - msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; + msg->sg.end = num_sge; msg->skb = skb; sk_psock_queue_msg(psock, msg); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 8a56e09cfb0e..e38705165ac9 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, struct sk_msg *msg, int *copied, int flags) { - bool cork = false, enospc = msg->sg.start == msg->sg.end; + bool cork = false, enospc = sk_msg_full(msg); struct sock *sk_redir; u32 tosend, delta = 0; int ret; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index bdca31ffe6da..b3da6c5ab999 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -209,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } -bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx) +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) { struct scatterlist *sg; - sg = ctx->partially_sent_record; - if (!sg) - return false; - - while (1) { + for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { put_page(sg_page(sg)); sk_mem_uncharge(sk, sg->length); - - if (sg_is_last(sg)) - break; - sg++; } ctx->partially_sent_record = NULL; - return true; } static void tls_write_space(struct sock *sk) diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index da9f9ce51e7b..2b2d0bae14a9 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -710,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags, } i = msg_pl->sg.start; - sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? - &msg_en->sg.data[i] : &msg_pl->sg.data[i]); + sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); i = msg_en->sg.end; sk_msg_iter_var_prev(i); @@ -771,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, policy = !(flags & MSG_SENDPAGE_NOPOLICY); psock = sk_psock_get(sk); - if (!psock || !policy) - return tls_push_record(sk, flags, record_type); + if (!psock || !policy) { + err = tls_push_record(sk, flags, record_type); + if (err) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + } + return err; + } more_data: enospc = sk_msg_full(msg); if (psock->eval == __SK_NONE) { @@ -970,8 +975,6 @@ alloc_encrypted: if (ret) goto fallback_to_reg_send; - rec->inplace_crypto = 0; - num_zc++; copied += try_to_copy; @@ -984,7 +987,7 @@ alloc_encrypted: num_async++; else if (ret == -ENOMEM) goto wait_for_memory; - else if (ret == -ENOSPC) + else if (ctx->open_rec && ret == -ENOSPC) goto rollback_iter; else if (ret != -EAGAIN) goto send_end; @@ -1053,11 +1056,12 @@ wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { trim_sgl: - tls_trim_both_msgs(sk, orig_size); + if (ctx->open_rec) + tls_trim_both_msgs(sk, orig_size); goto send_end; } - if (msg_en->sg.size < required_size) + if (ctx->open_rec && msg_en->sg.size < required_size) goto alloc_encrypted; } @@ -1169,7 +1173,6 @@ alloc_payload: tls_ctx->pending_open_record_frags = true; if (full_record || eor || sk_msg_full(msg_pl)) { - rec->inplace_crypto = 0; ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, record_type, &copied, flags); if (ret) { @@ -1190,11 +1193,13 @@ wait_for_sndbuf: wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { - tls_trim_both_msgs(sk, msg_pl->sg.size); + if (ctx->open_rec) + tls_trim_both_msgs(sk, msg_pl->sg.size); goto sendpage_end; } - goto alloc_payload; + if (ctx->open_rec) + goto alloc_payload; } if (num_async) { @@ -2084,7 +2089,8 @@ void tls_sw_release_resources_tx(struct sock *sk) /* Free up un-sent records in tx_list. First, free * the partially sent record if any at head of tx_list. */ - if (tls_free_partial_record(sk, tls_ctx)) { + if (tls_ctx->partially_sent_record) { + tls_free_partial_record(sk, tls_ctx); rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); list_del(&rec->list); diff --git a/tools/testing/selftests/bpf/test_sockmap.c b/tools/testing/selftests/bpf/test_sockmap.c index 3845144e2c91..4a851513c842 100644 --- a/tools/testing/selftests/bpf/test_sockmap.c +++ b/tools/testing/selftests/bpf/test_sockmap.c @@ -240,14 +240,14 @@ static int sockmap_init_sockets(int verbose) addr.sin_port = htons(S1_PORT); err = bind(s1, (struct sockaddr *)&addr, sizeof(addr)); if (err < 0) { - perror("bind s1 failed()\n"); + perror("bind s1 failed()"); return errno; } addr.sin_port = htons(S2_PORT); err = bind(s2, (struct sockaddr *)&addr, sizeof(addr)); if (err < 0) { - perror("bind s2 failed()\n"); + perror("bind s2 failed()"); return errno; } @@ -255,14 +255,14 @@ static int sockmap_init_sockets(int verbose) addr.sin_port = htons(S1_PORT); err = listen(s1, 32); if (err < 0) { - perror("listen s1 failed()\n"); + perror("listen s1 failed()"); return errno; } addr.sin_port = htons(S2_PORT); err = listen(s2, 32); if (err < 0) { - perror("listen s1 failed()\n"); + perror("listen s1 failed()"); return errno; } @@ -270,14 +270,14 @@ static int sockmap_init_sockets(int verbose) addr.sin_port = htons(S1_PORT); err = connect(c1, (struct sockaddr *)&addr, sizeof(addr)); if (err < 0 && errno != EINPROGRESS) { - perror("connect c1 failed()\n"); + perror("connect c1 failed()"); return errno; } addr.sin_port = htons(S2_PORT); err = connect(c2, (struct sockaddr *)&addr, sizeof(addr)); if (err < 0 && errno != EINPROGRESS) { - perror("connect c2 failed()\n"); + perror("connect c2 failed()"); return errno; } else if (err < 0) { err = 0; @@ -286,13 +286,13 @@ static int sockmap_init_sockets(int verbose) /* Accept Connecrtions */ p1 = accept(s1, NULL, NULL); if (p1 < 0) { - perror("accept s1 failed()\n"); + perror("accept s1 failed()"); return errno; } p2 = accept(s2, NULL, NULL); if (p2 < 0) { - perror("accept s1 failed()\n"); + perror("accept s1 failed()"); return errno; } @@ -332,6 +332,10 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt, int i, fp; file = fopen(".sendpage_tst.tmp", "w+"); + if (!file) { + perror("create file for sendpage"); + return 1; + } for (i = 0; i < iov_length * cnt; i++, k++) fwrite(&k, sizeof(char), 1, file); fflush(file); @@ -339,12 +343,17 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt, fclose(file); fp = open(".sendpage_tst.tmp", O_RDONLY); + if (fp < 0) { + perror("reopen file for sendpage"); + return 1; + } + clock_gettime(CLOCK_MONOTONIC, &s->start); for (i = 0; i < cnt; i++) { int sent = sendfile(fd, fp, NULL, iov_length); if (!drop && sent < 0) { - perror("send loop error:"); + perror("send loop error"); close(fp); return sent; } else if (drop && sent >= 0) { @@ -463,7 +472,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, int sent = sendmsg(fd, &msg, flags); if (!drop && sent < 0) { - perror("send loop error:"); + perror("send loop error"); goto out_errno; } else if (drop && sent >= 0) { printf("send loop error expected: %i\n", sent); @@ -499,7 +508,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, total_bytes -= txmsg_pop_total; err = clock_gettime(CLOCK_MONOTONIC, &s->start); if (err < 0) - perror("recv start time: "); + perror("recv start time"); while (s->bytes_recvd < total_bytes) { if (txmsg_cork) { timeout.tv_sec = 0; @@ -543,7 +552,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, if (recv < 0) { if (errno != EWOULDBLOCK) { clock_gettime(CLOCK_MONOTONIC, &s->end); - perror("recv failed()\n"); + perror("recv failed()"); goto out_errno; } } @@ -557,7 +566,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, errno = msg_verify_data(&msg, recv, chunk_sz); if (errno) { - perror("data verify msg failed\n"); + perror("data verify msg failed"); goto out_errno; } if (recvp) { @@ -565,7 +574,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt, recvp, chunk_sz); if (errno) { - perror("data verify msg_peek failed\n"); + perror("data verify msg_peek failed"); goto out_errno; } } @@ -654,7 +663,7 @@ static int sendmsg_test(struct sockmap_options *opt) err = 0; exit(err ? 1 : 0); } else if (rxpid == -1) { - perror("msg_loop_rx: "); + perror("msg_loop_rx"); return errno; } @@ -681,7 +690,7 @@ static int sendmsg_test(struct sockmap_options *opt) s.bytes_recvd, recvd_Bps, recvd_Bps/giga); exit(err ? 1 : 0); } else if (txpid == -1) { - perror("msg_loop_tx: "); + perror("msg_loop_tx"); return errno; } @@ -715,7 +724,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) /* Ping/Pong data from client to server */ sc = send(c1, buf, sizeof(buf), 0); if (sc < 0) { - perror("send failed()\n"); + perror("send failed()"); return sc; } @@ -748,7 +757,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) rc = recv(i, buf, sizeof(buf), 0); if (rc < 0) { if (errno != EWOULDBLOCK) { - perror("recv failed()\n"); + perror("recv failed()"); return rc; } } @@ -760,7 +769,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt) sc = send(i, buf, rc, 0); if (sc < 0) { - perror("send failed()\n"); + perror("send failed()"); return sc; } } diff --git a/tools/testing/selftests/bpf/xdping.c b/tools/testing/selftests/bpf/xdping.c index d60a343b1371..842d9155d36c 100644 --- a/tools/testing/selftests/bpf/xdping.c +++ b/tools/testing/selftests/bpf/xdping.c @@ -45,7 +45,7 @@ static int get_stats(int fd, __u16 count, __u32 raddr) printf("\nXDP RTT data:\n"); if (bpf_map_lookup_elem(fd, &raddr, &pinginfo)) { - perror("bpf_map_lookup elem: "); + perror("bpf_map_lookup elem"); return 1; } diff --git a/tools/testing/selftests/net/tls.c b/tools/testing/selftests/net/tls.c index 1c8f194d6556..46abcae47dee 100644 --- a/tools/testing/selftests/net/tls.c +++ b/tools/testing/selftests/net/tls.c @@ -268,6 +268,38 @@ TEST_F(tls, sendmsg_single) EXPECT_EQ(memcmp(buf, test_str, send_len), 0); } +#define MAX_FRAGS 64 +#define SEND_LEN 13 +TEST_F(tls, sendmsg_fragmented) +{ + char const *test_str = "test_sendmsg"; + char buf[SEND_LEN * MAX_FRAGS]; + struct iovec vec[MAX_FRAGS]; + struct msghdr msg; + int i, frags; + + for (frags = 1; frags <= MAX_FRAGS; frags++) { + for (i = 0; i < frags; i++) { + vec[i].iov_base = (char *)test_str; + vec[i].iov_len = SEND_LEN; + } + + memset(&msg, 0, sizeof(struct msghdr)); + msg.msg_iov = vec; + msg.msg_iovlen = frags; + + EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags); + EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL), + SEND_LEN * frags); + + for (i = 0; i < frags; i++) + EXPECT_EQ(memcmp(buf + SEND_LEN * i, + test_str, SEND_LEN), 0); + } +} +#undef MAX_FRAGS +#undef SEND_LEN + TEST_F(tls, sendmsg_large) { void *mem = malloc(16384); @@ -694,6 +726,34 @@ TEST_F(tls, recv_lowat) EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0); } +TEST_F(tls, recv_rcvbuf) +{ + char send_mem[4096]; + char recv_mem[4096]; + int rcv_buf = 1024; + + memset(send_mem, 0x1c, sizeof(send_mem)); + + EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVBUF, + &rcv_buf, sizeof(rcv_buf)), 0); + + EXPECT_EQ(send(self->fd, send_mem, 512, 0), 512); + memset(recv_mem, 0, sizeof(recv_mem)); + EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), 512); + EXPECT_EQ(memcmp(send_mem, recv_mem, 512), 0); + + if (self->notls) + return; + + EXPECT_EQ(send(self->fd, send_mem, 4096, 0), 4096); + memset(recv_mem, 0, sizeof(recv_mem)); + EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1); + EXPECT_EQ(errno, EMSGSIZE); + + EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1); + EXPECT_EQ(errno, EMSGSIZE); +} + TEST_F(tls, bidir) { char const *test_str = "test_read";