Something like this, perhaps - it'll ensure that io-wq workers get a
chance to flush out pending work, which should prevent the looping. I've
attached a basic test case. It'll issue a write that will fault, and
then try to cancel that write as a way to trigger the
TIF_NOTIFY_SIGNAL-based looping.
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index d80f94346199..e18926dbf20a 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -32,6 +32,7 @@
#include <linux/swapops.h>
#include <linux/miscdevice.h>
#include <linux/uio.h>
+#include <linux/io_uring.h>
static int sysctl_unprivileged_userfaultfd __read_mostly;
@@ -376,6 +377,8 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
*/
if (current->flags & (PF_EXITING|PF_DUMPCORE))
goto out;
+ else if (current->flags & PF_IO_WORKER)
+ io_worker_fault();
assert_fault_locked(vmf);
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 85fe4e6b275c..d93dd7402a28 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -28,6 +28,7 @@ static inline void io_uring_free(struct task_struct *tsk)
if (tsk->io_uring)
__io_uring_free(tsk);
}
+void io_worker_fault(void);
#else
static inline void io_uring_task_cancel(void)
{
@@ -46,6 +47,9 @@ static inline bool io_is_uring_fops(struct file *file)
{
return false;
}
+static inline void io_worker_fault(void)
+{
+}
#endif
#endif
diff --git a/io_uring/io-wq.c b/io_uring/io-wq.c
index d52069b1177b..f74bea028ec7 100644
--- a/io_uring/io-wq.c
+++ b/io_uring/io-wq.c
@@ -1438,3 +1438,13 @@ static __init int io_wq_init(void)
return 0;
}
subsys_initcall(io_wq_init);
+
+void io_worker_fault(void)
+{
+ if (test_thread_flag(TIF_NOTIFY_SIGNAL))
+ clear_notify_signal();
+ if (test_thread_flag(TIF_NOTIFY_RESUME))
+ resume_user_mode_work(NULL);
+ if (task_work_pending(current))
+ task_work_run();
+}
--
Jens Axboe
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/uio.h>
#include <linux/mman.h>
#include <linux/userfaultfd.h>
#include <liburing.h>
/* Size of one 2MB huge page; the test buffer is exactly one such page. */
#define HP_SIZE (2 * 1024 * 1024ULL)
/* Not referenced by the code below. */
#define NR_HUGEPAGES (3000)

/*
 * Syscall number for userfaultfd(2). Prefer the kernel header's value:
 * 282 is only correct for asm-generic architectures (e.g. arm64, riscv).
 * On x86_64, syscall 282 is signalfd(2) and userfaultfd is 323, so the
 * hard-coded fallback would silently invoke the wrong syscall there.
 */
#ifndef NR_userfaultfd
#if defined(__NR_userfaultfd)
#define NR_userfaultfd __NR_userfaultfd
#else
#define NR_userfaultfd 282
#endif
#endif
/*
 * State shared between the main thread and the userfaultfd handler thread.
 * Fields are only ever accessed by name.
 */
struct thread_data {
	int uffd;			/* userfaultfd descriptor the handler polls */
	pthread_barrier_t barrier;	/* start rendezvous for the two threads */
	pthread_t thread;		/* handle of the fault-handling thread */
};
/*
 * Body of the userfaultfd handler thread.
 *
 * Waits on td->barrier, then loops forever: poll td->uffd, read one
 * struct uffd_msg, and log it. It never resolves a fault (no UFFDIO_COPY /
 * UFFDIO_ZEROPAGE), so the faulting access stays blocked - the test
 * depends on that to have something to cancel.
 *
 * Exits the whole process on an unexpected error or a non-pagefault event.
 */
static void *fault_handler(void *data)
{
	struct thread_data *td = data;
	struct uffd_msg msg;
	struct pollfd pfd;
	ssize_t ret;
	int nready;

	pthread_barrier_wait(&td->barrier);

	for (;;) {
		pfd.fd = td->uffd;
		pfd.events = POLLIN;
		nready = poll(&pfd, 1, -1);
		if (nready < 0) {
			/*
			 * SIGUSR1 is installed without SA_RESTART, so a signal
			 * landing on this thread makes poll() fail with EINTR.
			 */
			if (errno == EINTR)
				continue;
			perror("poll");
			exit(1);
		}

		ret = read(td->uffd, &msg, sizeof(msg));
		if (ret < 0) {
			/* uffd was opened O_NONBLOCK; EAGAIN just means retry */
			if (errno == EAGAIN || errno == EINTR)
				continue;
			perror("read");
			exit(1);
		}
		if ((size_t) ret != sizeof(msg)) {
			fprintf(stderr, "short uffd read: %zd\n", ret);
			exit(1);
		}

		if (msg.event != UFFD_EVENT_PAGEFAULT) {
			printf("unspected event: %x\n", msg.event);
			exit(1);
		}
		printf("Page fault\n");
		printf("flags = %llx; ",
		       (unsigned long long) msg.arg.pagefault.flags);
		printf("address = %llx\n",
		       (unsigned long long) msg.arg.pagefault.address);
	}
	return NULL;
}
/*
 * Issue a write sourced from @buf (expected to fault on userfaultfd),
 * then cancel it by user_data and reap both completions.
 *
 * @ring: initialized io_uring with room for at least two SQEs
 * @buf:  userfaultfd-registered buffer whose pages are not populated
 * @len:  number of bytes to write from @buf
 *
 * WARNING: the target is the raw block device /dev/nvme0n1 at offset 0.
 * The write is meant to stay blocked on the fault and be cancelled; if it
 * ever completes it overwrites the start of that device.
 */
static void do_io(struct io_uring *ring, void *buf, size_t len)
{
	struct io_uring_sqe *sqe;
	struct io_uring_cqe *cqe;
	int fd, ret, i;

	fd = open("/dev/nvme0n1", O_RDWR);
	if (fd < 0) {
		perror("open create");
		return;
	}

	/* issue faulting write */
	sqe = io_uring_get_sqe(ring);
	if (!sqe) {
		fprintf(stderr, "get sqe failed\n");
		goto out;
	}
	io_uring_prep_write(sqe, fd, buf, len, 0);
	sqe->user_data = 1;
	ret = io_uring_submit(ring);
	if (ret < 0) {
		fprintf(stderr, "submit: %d\n", ret);
		goto out;
	}
	printf("blocking issued\n");
	sleep(1);

	/* cancel above write, matching it by its user_data */
	sqe = io_uring_get_sqe(ring);
	if (!sqe) {
		fprintf(stderr, "get sqe failed\n");
		goto out;
	}
	io_uring_prep_cancel64(sqe, 1, IORING_ASYNC_CANCEL_USERDATA);
	sqe->user_data = 2;
	ret = io_uring_submit(ring);
	if (ret < 0) {
		fprintf(stderr, "submit: %d\n", ret);
		goto out;
	}
	printf("cancel issued\n");
	sleep(1);

	/* reap the write and cancel CQEs; retry waits cut short by a signal */
	for (i = 0; i < 2; i++) {
		do {
			ret = io_uring_wait_cqe(ring, &cqe);
			if (ret < 0)
				printf("wait: %d\n", ret);
		} while (ret == -EINTR);
		if (ret < 0)
			break;
		printf("got res %d, %ld\n", cqe->res, (long) cqe->user_data);
		io_uring_cqe_seen(ring, cqe);
	}
out:
	/* io_uring holds its own file reference, so closing here is safe */
	close(fd);
}
/*
 * SIGUSR1 handler: reports that the signal arrived.
 *
 * Uses write(2) rather than printf(), because printf() is not
 * async-signal-safe. errno is saved and restored so that interrupted code
 * inspecting errno (e.g. after an EINTR) sees its own value.
 */
static void sig_usr1(int sig)
{
	static const char msg[] = "got USR1\n";
	int saved_errno = errno;
	ssize_t ret;

	(void) sig;
	ret = write(STDOUT_FILENO, msg, sizeof(msg) - 1);
	(void) ret;	/* nothing useful to do on failure inside a handler */
	errno = saved_errno;
}
/*
 * Set up and run the reproducer:
 *  - install a SIGUSR1 handler (no SA_RESTART);
 *  - create a 4-entry io_uring;
 *  - mmap one 2MB huge page and register it with a userfaultfd in
 *    MISSING mode, so the first touch of the buffer faults to userspace;
 *  - start fault_handler() and run do_io() against the untouched buffer.
 *
 * Returns 0 once do_io() returns, 1 on any setup failure. The ring,
 * mapping, uffd and handler thread are not torn down; main() returns
 * right afterwards.
 */
static int test(void)
{
	struct uffdio_api api = { };
	struct uffdio_register reg = { };
	struct io_uring ring;
	struct sigaction act = { };
	struct thread_data td = { };
	void *buf;
	int ret;

	act.sa_handler = sig_usr1;
	if (sigaction(SIGUSR1, &act, NULL) < 0) {
		perror("sigaction");
		return 1;
	}

	ret = io_uring_queue_init(4, &ring, 0);
	if (ret < 0) {
		fprintf(stderr, "queue_init: %s\n", strerror(-ret));
		return 1;
	}

	buf = mmap(NULL, HP_SIZE, PROT_READ|PROT_WRITE,
		   MAP_PRIVATE | MAP_HUGETLB | MAP_HUGE_2MB | MAP_ANONYMOUS,
		   -1, 0);
	if (buf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}
	printf("got buf %p\n", buf);

	td.uffd = syscall(NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
	if (td.uffd < 0) {
		perror("userfaultfd");
		return 1;
	}

	api.api = UFFD_API;
	if (ioctl(td.uffd, UFFDIO_API, &api) < 0) {
		perror("ioctl UFFDIO_API");
		return 1;
	}

	reg.range.start = (unsigned long) buf;
	reg.range.len = HP_SIZE;
	reg.mode = UFFDIO_REGISTER_MODE_MISSING;
	if (ioctl(td.uffd, UFFDIO_REGISTER, &reg) < 0) {
		perror("ioctl UFFDIO_REGISTER");
		return 1;
	}

	pthread_barrier_init(&td.barrier, NULL, 2);
	ret = pthread_create(&td.thread, NULL, fault_handler, &td);
	if (ret) {
		fprintf(stderr, "pthread_create: %s\n", strerror(ret));
		return 1;
	}
	/* don't issue I/O until the handler thread is about to poll */
	pthread_barrier_wait(&td.barrier);

	do_io(&ring, buf, HP_SIZE);
	return 0;
}
/* Program entry: arguments are ignored; exit status is test()'s result. */
int main(int argc, char *argv[])
{
	int status;

	(void) argc;
	(void) argv;

	status = test();
	return status;
}