diff --git a/fs/io_uring.c b/fs/io_uring.c
index e5c2bb258db0..80db5898e119 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -4898,17 +4898,25 @@ static void io_poll_remove_double(struct io_kiocb *req)
 	}
 }
 
-static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
+static bool io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
 {
 	struct io_ring_ctx *ctx = req->ctx;
+	unsigned flags = IORING_CQE_F_MORE;
 
-	if (!error && req->poll.canceled)
+	if (!error && req->poll.canceled) {
 		error = -ECANCELED;
-
-	io_poll_remove_double(req);
-	req->poll.done = true;
-	io_cqring_fill_event(req, error ? error : mangle_poll(mask));
+		req->poll.events |= EPOLLONESHOT;
+	}
+	if (error || (req->poll.events & EPOLLONESHOT)) {
+		io_poll_remove_double(req);
+		req->poll.done = true;
+		flags = 0;
+	}
+	if (!error)
+		error = mangle_poll(mask);
+	__io_cqring_fill_event(req, error, flags);
 	io_commit_cqring(ctx);
+	return !(flags & IORING_CQE_F_MORE);
 }
 
 static void io_poll_task_func(struct callback_head *cb)
@@ -4920,14 +4928,25 @@ static void io_poll_task_func(struct callback_head *cb)
 	if (io_poll_rewait(req, &req->poll)) {
 		spin_unlock_irq(&ctx->completion_lock);
 	} else {
-		hash_del(&req->hash_node);
-		io_poll_complete(req, req->result, 0);
+		bool done, post_ev;
+
+		post_ev = done = io_poll_complete(req, req->result, 0);
+		if (done) {
+			hash_del(&req->hash_node);
+		} else if (!(req->poll.events & EPOLLONESHOT)) {
+			post_ev = true;
+			req->result = 0;
+			add_wait_queue(req->poll.head, &req->poll.wait);
+		}
 		spin_unlock_irq(&ctx->completion_lock);
 
-		nxt = io_put_req_find_next(req);
-		io_cqring_ev_posted(ctx);
-		if (nxt)
-			__io_req_task_submit(nxt);
+		if (post_ev)
+			io_cqring_ev_posted(ctx);
+		if (done) {
+			nxt = io_put_req_find_next(req);
+			if (nxt)
+				__io_req_task_submit(nxt);
+		}
 	}
 }
 
@@ -4941,6 +4960,8 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 	/* for instances that support it check for an event match first: */
 	if (mask && !(mask & poll->events))
 		return 0;
+	if (!(poll->events & EPOLLONESHOT))
+		return poll->wait.func(&poll->wait, mode, sync, key);
 
 	list_del_init(&wait->entry);
 
@@ -5106,7 +5127,7 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
 		ipt->error = 0;
 		mask = 0;
 	}
-	if (mask || ipt->error)
+	if ((mask && (poll->events & EPOLLONESHOT)) || ipt->error)
 		list_del_init(&poll->wait.entry);
 	else if (cancel)
 		WRITE_ONCE(poll->canceled, true);
@@ -5149,7 +5170,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 	req->flags |= REQ_F_POLLED;
 	req->apoll = apoll;
 
-	mask = 0;
+	mask = EPOLLONESHOT;
 	if (def->pollin)
 		mask |= POLLIN | POLLRDNORM;
 	if (def->pollout)
@@ -5322,18 +5343,24 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
 	struct io_poll_iocb *poll = &req->poll;
-	u32 events;
+	u32 events, flags;
 
 	if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
 		return -EINVAL;
-	if (sqe->addr || sqe->ioprio || sqe->off || sqe->len || sqe->buf_index)
+	if (sqe->addr || sqe->ioprio || sqe->off || sqe->buf_index)
+		return -EINVAL;
+	flags = READ_ONCE(sqe->len);
+	if (flags & ~IORING_POLL_ADD_MULTI)
 		return -EINVAL;
 
 	events = READ_ONCE(sqe->poll32_events);
#ifdef __BIG_ENDIAN
 	events = swahw32(events);
#endif
-	poll->events = demangle_poll(events) | (events & EPOLLEXCLUSIVE);
+	if (!flags)
+		events |= EPOLLONESHOT;
+	poll->events = demangle_poll(events) |
+				(events & (EPOLLEXCLUSIVE|EPOLLONESHOT));
 	return 0;
 }
 
@@ -5357,7 +5384,8 @@ static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
 
 	if (mask) {
 		io_cqring_ev_posted(ctx);
-		io_put_req(req);
+		if (poll->events & EPOLLONESHOT)
+			io_put_req(req);
 	}
 	return ipt.error;
 }
diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h
index 2514eb6b1cf2..76c967621601 100644
--- a/include/uapi/linux/io_uring.h
+++ b/include/uapi/linux/io_uring.h
@@ -159,6 +159,16 @@ enum {
  */
 #define SPLICE_F_FD_IN_FIXED	(1U << 31) /* the last bit of __u32 */
 
+/*
+ * POLL_ADD flags. Note that since sqe->poll_events is the flag space, the
+ * command flags for POLL_ADD are stored in sqe->len.
+ *
+ * IORING_POLL_ADD_MULTI	Multishot poll. Sets IORING_CQE_F_MORE if
+ *				the poll handler will continue to report
+ *				CQEs on behalf of the same SQE.
+ */
+#define IORING_POLL_ADD_MULTI	(1U << 0)
+
 /*
  * IO completion data structure (Completion Queue Entry)
  */
@@ -172,8 +182,10 @@ struct io_uring_cqe {
  * cqe->flags
  *
  * IORING_CQE_F_BUFFER	If set, the upper 16 bits are the buffer ID
+ * IORING_CQE_F_MORE	If set, parent SQE will generate more CQE entries
  */
 #define IORING_CQE_F_BUFFER		(1U << 0)
+#define IORING_CQE_F_MORE		(1U << 1)
 
 enum {
 	IORING_CQE_BUFFER_SHIFT		= 16,
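
For reference, a minimal userspace sketch (not part of this patch) of driving multishot poll through liburing. io_uring_prep_poll_add() is the stock liburing helper; per the UAPI comment above, the POLL_ADD command flag space is sqe->len, so the sketch sets IORING_POLL_ADD_MULTI there and keeps reaping CQEs for the one SQE until the kernel stops setting IORING_CQE_F_MORE. The function name and error handling are illustrative only.

#include <liburing.h>
#include <poll.h>
#include <stdio.h>

/* Arm one multishot POLL_ADD on fd, then reap CQEs until F_MORE is cleared. */
static int poll_multishot(struct io_uring *ring, int fd)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
	struct io_uring_cqe *cqe;
	int ret, more;

	if (!sqe)
		return -EAGAIN;
	io_uring_prep_poll_add(sqe, fd, POLLIN);
	sqe->len = IORING_POLL_ADD_MULTI;	/* request multishot mode */
	ret = io_uring_submit(ring);
	if (ret < 0)
		return ret;

	do {
		ret = io_uring_wait_cqe(ring, &cqe);
		if (ret < 0)
			return ret;
		/* cqe->res is the mangled poll mask, or a negative error */
		printf("res=%d flags=0x%x\n", cqe->res, cqe->flags);
		more = cqe->flags & IORING_CQE_F_MORE;
		io_uring_cqe_seen(ring, cqe);
	} while (more);		/* F_MORE clear: this poll is terminated */

	return 0;
}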