In RPAL, there are two roles: the sender (caller) and the receiver
(callee). This patch provides an interface for threads to register as
a sender or a receiver with the kernel. Each sender and receiver has
its own data structure, along with a block of memory shared between
user space and kernel space, which is allocated through rpal_mmap().

Signed-off-by: Bo Li <libo.gcs85@xxxxxxxxxxxxx>
---
 arch/x86/rpal/Makefile   |   2 +-
 arch/x86/rpal/internal.h |   7 ++
 arch/x86/rpal/proc.c     |  12 +++
 arch/x86/rpal/service.c  |   6 ++
 arch/x86/rpal/thread.c   | 165 +++++++++++++++++++++++++++++++++++++++
 include/linux/rpal.h     |  79 +++++++++++++++++++
 include/linux/sched.h    |  15 ++++
 init/init_task.c         |   2 +
 kernel/fork.c            |   2 +
 9 files changed, 289 insertions(+), 1 deletion(-)
 create mode 100644 arch/x86/rpal/thread.c

diff --git a/arch/x86/rpal/Makefile b/arch/x86/rpal/Makefile
index a5926fc19334..89f745382c51 100644
--- a/arch/x86/rpal/Makefile
+++ b/arch/x86/rpal/Makefile
@@ -2,4 +2,4 @@
 
 obj-$(CONFIG_RPAL) += rpal.o
 
-rpal-y := service.o core.o mm.o proc.o
+rpal-y := service.o core.o mm.o proc.o thread.o
diff --git a/arch/x86/rpal/internal.h b/arch/x86/rpal/internal.h
index 65fd14a26f0e..3559c9c6e868 100644
--- a/arch/x86/rpal/internal.h
+++ b/arch/x86/rpal/internal.h
@@ -34,3 +34,10 @@ static inline void rpal_put_shared_page(struct rpal_shared_page *rsp)
 int rpal_mmap(struct file *filp, struct vm_area_struct *vma);
 struct rpal_shared_page *rpal_find_shared_page(struct rpal_service *rs,
 					       unsigned long addr);
+
+/* thread.c */
+int rpal_register_sender(unsigned long addr);
+int rpal_unregister_sender(void);
+int rpal_register_receiver(unsigned long addr);
+int rpal_unregister_receiver(void);
+void exit_rpal_thread(void);
diff --git a/arch/x86/rpal/proc.c b/arch/x86/rpal/proc.c
index 86947dc233d0..8a1e4a8a2271 100644
--- a/arch/x86/rpal/proc.c
+++ b/arch/x86/rpal/proc.c
@@ -51,6 +51,18 @@ static long rpal_ioctl(struct file *file, unsigned int cmd, unsigned long arg)
 	case RPAL_IOCTL_GET_SERVICE_ID:
 		ret = put_user(cur->id, (int __user *)arg);
 		break;
+	case RPAL_IOCTL_REGISTER_SENDER:
+		ret = rpal_register_sender(arg);
+		break;
+	case RPAL_IOCTL_UNREGISTER_SENDER:
+		ret = rpal_unregister_sender();
+		break;
+	case RPAL_IOCTL_REGISTER_RECEIVER:
+		ret = rpal_register_receiver(arg);
+		break;
+	case RPAL_IOCTL_UNREGISTER_RECEIVER:
+		ret = rpal_unregister_receiver();
+		break;
 	default:
 		return -EINVAL;
 	}
diff --git a/arch/x86/rpal/service.c b/arch/x86/rpal/service.c
index f29a046fc22f..42fb719dbb2a 100644
--- a/arch/x86/rpal/service.c
+++ b/arch/x86/rpal/service.c
@@ -176,6 +176,7 @@ struct rpal_service *rpal_register_service(void)
 	mutex_init(&rs->mutex);
 	rs->nr_shared_pages = 0;
 	INIT_LIST_HEAD(&rs->shared_pages);
+	atomic_set(&rs->thread_cnt, 0);
 	rs->bad_service = false;
 	rs->base = calculate_base_address(rs->id);
 
@@ -216,6 +217,9 @@ void rpal_unregister_service(struct rpal_service *rs)
 	if (!rs)
 		return;
 
+	while (atomic_read(&rs->thread_cnt) != 0)
+		schedule();
+
 	delete_service(rs);
 
 	pr_debug("rpal: unregister service, id: %d, tgid: %d\n", rs->id,
@@ -238,6 +242,8 @@ void exit_rpal(bool group_dead)
 	if (!rs)
 		return;
 
+	exit_rpal_thread();
+
 	current->rpal_rs = NULL;
 	rpal_put_service(rs);
 
diff --git a/arch/x86/rpal/thread.c b/arch/x86/rpal/thread.c
new file mode 100644
index 000000000000..7550ad94b63f
--- /dev/null
+++ b/arch/x86/rpal/thread.c
@@ -0,0 +1,165 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * RPAL thread level operations (sender/receiver registration)
+ * Copyright (c) 2025, ByteDance. All rights reserved.
+ *
+ * Author: Jiadong Sun <sunjiadong.lff@xxxxxxxxxxxxx>
+ */
+
+#include <linux/rpal.h>
+
+#include "internal.h"
+
+/* Fill in the role-independent part of a sender/receiver descriptor. */
+static void rpal_common_data_init(struct rpal_common_data *rcd)
+{
+	rcd->bp_task = current;
+	rcd->service_id = rpal_current_service()->id;
+}
+
+/*
+ * task_struct keeps rpal_sd and rpal_rd in one union, so a thread can
+ * hold only one RPAL role at a time. Registering a second role would
+ * overwrite the first descriptor pointer, leak that descriptor and its
+ * shared-page reference, and make the later unregister kfree() the
+ * wrong object.
+ */
+static bool rpal_current_thread_has_role(void)
+{
+	return rpal_test_current_thread_flag(RPAL_SENDER_BIT) ||
+	       rpal_test_current_thread_flag(RPAL_RECEIVER_BIT);
+}
+
+/**
+ * rpal_register_sender - register the current thread as an RPAL sender
+ * @addr: user address of a struct rpal_sender_call_context located in a
+ *        shared page of the current service (allocated via rpal_mmap())
+ *
+ * On success the thread holds a reference on that shared page until it
+ * unregisters, and counts towards the service's thread_cnt.
+ *
+ * Return: 0 on success; -EINVAL if the thread already has an RPAL role,
+ * or if @addr does not name a call context fully inside a shared page;
+ * -ENOMEM if the descriptor cannot be allocated.
+ */
+int rpal_register_sender(unsigned long addr)
+{
+	struct rpal_service *cur = rpal_current_service();
+	struct rpal_shared_page *rsp;
+	struct rpal_sender_data *rsd;
+	int ret;
+
+	/* Checks both roles: see rpal_current_thread_has_role(). */
+	if (rpal_current_thread_has_role()) {
+		ret = -EINVAL;
+		goto out;
+	}
+
+	rsp = rpal_find_shared_page(cur, addr);
+	if (!rsp) {
+		ret = -EINVAL;
+		goto out;
+	}
+
+	/* The whole call context must lie inside the shared page range. */
+	if (addr + sizeof(struct rpal_sender_call_context) >
+	    rsp->user_start + rsp->npage * PAGE_SIZE) {
+		ret = -EINVAL;
+		goto put_shared_page;
+	}
+
+	rsd = kzalloc(sizeof(*rsd), GFP_KERNEL);
+	if (!rsd) {
+		ret = -ENOMEM;
+		goto put_shared_page;
+	}
+
+	rpal_common_data_init(&rsd->rcd);
+	rsd->rsp = rsp;
+	/* Kernel-side address of the same bytes userspace sees at @addr. */
+	rsd->scc = (struct rpal_sender_call_context *)(addr - rsp->user_start +
+						       rsp->kernel_start);
+
+	current->rpal_sd = rsd;
+	rpal_set_current_thread_flag(RPAL_SENDER_BIT);
+
+	atomic_inc(&cur->thread_cnt);
+
+	return 0;
+
+put_shared_page:
+	rpal_put_shared_page(rsp);
+out:
+	return ret;
+}
+
+/**
+ * rpal_unregister_sender - undo rpal_register_sender() for the current thread
+ *
+ * Drops the shared-page reference, frees the descriptor and decrements the
+ * service's thread_cnt.
+ *
+ * Return: 0 on success, -EINVAL if the thread is not registered as a sender.
+ */
+int rpal_unregister_sender(void)
+{
+	struct rpal_service *cur = rpal_current_service();
+	struct rpal_sender_data *rsd = current->rpal_sd;
+
+	if (!rpal_test_current_thread_flag(RPAL_SENDER_BIT))
+		return -EINVAL;
+
+	rpal_clear_current_thread_flag(RPAL_SENDER_BIT);
+	/* Do not leave a dangling pointer in task_struct (also seen as rpal_cd). */
+	current->rpal_sd = NULL;
+	rpal_put_shared_page(rsd->rsp);
+	kfree(rsd);
+
+	atomic_dec(&cur->thread_cnt);
+
+	return 0;
+}
+
+/**
+ * rpal_register_receiver - register the current thread as an RPAL receiver
+ * @addr: user address of a struct rpal_receiver_call_context located in a
+ *        shared page of the current service (allocated via rpal_mmap())
+ *
+ * On success the thread holds a reference on that shared page until it
+ * unregisters, and counts towards the service's thread_cnt.
+ *
+ * Return: 0 on success; -EINVAL if the thread already has an RPAL role,
+ * or if @addr does not name a call context fully inside a shared page;
+ * -ENOMEM if the descriptor cannot be allocated.
+ */
+int rpal_register_receiver(unsigned long addr)
+{
+	struct rpal_service *cur = rpal_current_service();
+	struct rpal_receiver_data *rrd;
+	struct rpal_shared_page *rsp;
+	int ret;
+
+	/* Checks both roles: see rpal_current_thread_has_role(). */
+	if (rpal_current_thread_has_role()) {
+		ret = -EINVAL;
+		goto out;
+	}
+
+	rsp = rpal_find_shared_page(cur, addr);
+	if (!rsp) {
+		ret = -EINVAL;
+		goto out;
+	}
+
+	/* The whole call context must lie inside the shared page range. */
+	if (addr + sizeof(struct rpal_receiver_call_context) >
+	    rsp->user_start + rsp->npage * PAGE_SIZE) {
+		ret = -EINVAL;
+		goto put_shared_page;
+	}
+
+	rrd = kzalloc(sizeof(*rrd), GFP_KERNEL);
+	if (!rrd) {
+		ret = -ENOMEM;
+		goto put_shared_page;
+	}
+
+	rpal_common_data_init(&rrd->rcd);
+	rrd->rsp = rsp;
+	/* Kernel-side address of the same bytes userspace sees at @addr. */
+	rrd->rcc = (struct rpal_receiver_call_context *)(addr - rsp->user_start +
+							 rsp->kernel_start);
+
+	current->rpal_rd = rrd;
+	rpal_set_current_thread_flag(RPAL_RECEIVER_BIT);
+
+	atomic_inc(&cur->thread_cnt);
+
+	return 0;
+
+put_shared_page:
+	rpal_put_shared_page(rsp);
+out:
+	return ret;
+}
+
+/**
+ * rpal_unregister_receiver - undo rpal_register_receiver() for the current thread
+ *
+ * Drops the shared-page reference, frees the descriptor and decrements the
+ * service's thread_cnt.
+ *
+ * Return: 0 on success, -EINVAL if the thread is not registered as a receiver.
+ */
+int rpal_unregister_receiver(void)
+{
+	struct rpal_service *cur = rpal_current_service();
+	struct rpal_receiver_data *rrd = current->rpal_rd;
+
+	if (!rpal_test_current_thread_flag(RPAL_RECEIVER_BIT))
+		return -EINVAL;
+
+	rpal_clear_current_thread_flag(RPAL_RECEIVER_BIT);
+	/* Do not leave a dangling pointer in task_struct (also seen as rpal_cd). */
+	current->rpal_rd = NULL;
+	rpal_put_shared_page(rrd->rsp);
+	kfree(rrd);
+
+	atomic_dec(&cur->thread_cnt);
+
+	return 0;
+}
+
+/* Called from exit_rpal(): drop whatever RPAL role the thread still holds. */
+void exit_rpal_thread(void)
+{
+	if (rpal_test_current_thread_flag(RPAL_SENDER_BIT))
+		rpal_unregister_sender();
+
+	if (rpal_test_current_thread_flag(RPAL_RECEIVER_BIT))
+		rpal_unregister_receiver();
+}
diff --git a/include/linux/rpal.h b/include/linux/rpal.h
index 986dfbd16fc9..c33425e896af 100644
--- a/include/linux/rpal.h
+++ b/include/linux/rpal.h
@@ -79,6 +79,11 @@
 
 extern unsigned long rpal_cap;
 
+/* Bit numbers within task_struct::rpal_flag. */
+enum rpal_task_flag_bits {
+	RPAL_SENDER_BIT,
+	RPAL_RECEIVER_BIT,
+};
+
 /*
  * Each RPAL process (a.k.a RPAL service) should have a pointer to
  * struct rpal_service in all its tasks' task_struct.
@@ -117,6 +122,9 @@ struct rpal_service {
 	int nr_shared_pages;
 	struct list_head shared_pages;
 
+	/* sender/receiver thread count */
+	atomic_t thread_cnt;
+
 	/* delayed service put work */
 	struct delayed_work delayed_put_work;
 
@@ -149,10 +157,55 @@ struct rpal_shared_page {
 	struct list_head list;
 };
 
+struct rpal_common_data {
+	/* back pointer to task_struct */
+	struct task_struct *bp_task;
+	/* service id of rpal_service */
+	int service_id;
+};
+
+/* Saved user register state: x86-64 callee-saved GPRs plus rip/rsp */
+struct rpal_task_context {
+	u64 r15;
+	u64 r14;
+	u64 r13;
+	u64 r12;
+	u64 rbx;
+	u64 rbp;
+	u64 rip;
+	u64 rsp;
+};
+
+struct rpal_receiver_call_context {
+	struct rpal_task_context rtc;
+	int receiver_id;
+};
+
+struct rpal_receiver_data {
+	struct rpal_common_data rcd;
+	struct rpal_shared_page *rsp;
+	struct rpal_receiver_call_context *rcc;
+};
+
+struct rpal_sender_call_context {
+	struct rpal_task_context rtc;
+	int sender_id;
+};
+
+struct rpal_sender_data {
+	struct rpal_common_data rcd;
+	struct rpal_shared_page *rsp;
+	struct rpal_sender_call_context *scc;
+};
+
 enum rpal_command_type {
 	RPAL_CMD_GET_API_VERSION_AND_CAP,
 	RPAL_CMD_GET_SERVICE_KEY,
 	RPAL_CMD_GET_SERVICE_ID,
+	RPAL_CMD_REGISTER_SENDER,
+	RPAL_CMD_UNREGISTER_SENDER,
+	RPAL_CMD_REGISTER_RECEIVER,
+	RPAL_CMD_UNREGISTER_RECEIVER,
 	RPAL_NR_CMD,
 };
 
@@ -165,6 +218,14 @@ enum rpal_command_type {
 	_IOWR(RPAL_IOCTL_MAGIC, RPAL_CMD_GET_SERVICE_KEY, u64 *)
 #define RPAL_IOCTL_GET_SERVICE_ID \
 	_IOWR(RPAL_IOCTL_MAGIC, RPAL_CMD_GET_SERVICE_ID, int *)
+#define RPAL_IOCTL_REGISTER_SENDER \
+	_IOWR(RPAL_IOCTL_MAGIC, RPAL_CMD_REGISTER_SENDER, unsigned long)
+#define RPAL_IOCTL_UNREGISTER_SENDER \
+	_IO(RPAL_IOCTL_MAGIC, RPAL_CMD_UNREGISTER_SENDER)
+#define RPAL_IOCTL_REGISTER_RECEIVER \
+	_IOWR(RPAL_IOCTL_MAGIC, RPAL_CMD_REGISTER_RECEIVER, unsigned long)
+#define RPAL_IOCTL_UNREGISTER_RECEIVER \
+	_IO(RPAL_IOCTL_MAGIC, RPAL_CMD_UNREGISTER_RECEIVER)
 
 /**
  * @brief get new reference to a rpal service, a corresponding
@@ -200,8 +261,26 @@ static inline struct rpal_service *rpal_current_service(void)
 {
 	return current->rpal_rs;
 }
+
+static inline void rpal_set_current_thread_flag(unsigned long bit)
+{
+	set_bit(bit, &current->rpal_flag);
+}
+
+static inline void rpal_clear_current_thread_flag(unsigned long bit)
+{
+	clear_bit(bit, &current->rpal_flag);
+}
+
+static inline bool rpal_test_current_thread_flag(unsigned long bit)
+{
+	return test_bit(bit, &current->rpal_flag);
+}
 #else
 static inline struct rpal_service *rpal_current_service(void) { return NULL; }
+static inline void rpal_set_current_thread_flag(unsigned long bit) { }
+static inline void rpal_clear_current_thread_flag(unsigned long bit) { }
+static inline bool rpal_test_current_thread_flag(unsigned long bit) { return false; }
 #endif
 
 void rpal_unregister_service(struct rpal_service *rs);
diff --git a/include/linux/sched.h b/include/linux/sched.h
index ad35b197543c..5f25cc09fb71 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -72,6 +72,9 @@ struct rcu_node;
 struct reclaim_state;
 struct robust_list_head;
 struct root_domain;
+struct rpal_common_data;
+struct rpal_receiver_data;
+struct rpal_sender_data;
 struct rpal_service;
 struct rq;
 struct sched_attr;
@@ -1648,6 +1651,18 @@ struct task_struct {
 
 #ifdef CONFIG_RPAL
 	struct rpal_service		*rpal_rs;
+	unsigned long			rpal_flag;
+	/*
+	 * rpal_sd and rpal_rd share storage, so a task must not hold
+	 * RPAL_SENDER_BIT and RPAL_RECEIVER_BIT at the same time. Both
+	 * start with a struct rpal_common_data, so use rpal_cd when the
+	 * role (sender or receiver) does not matter.
+	 */
+	union {
+		struct rpal_common_data		*rpal_cd;
+		struct rpal_sender_data		*rpal_sd;
+		struct rpal_receiver_data	*rpal_rd;
+	};
 #endif
 
 	/* CPU-specific state of this task: */
diff --git a/init/init_task.c b/init/init_task.c
index 0c5b1927da41..2eb08b96e66b 100644
--- a/init/init_task.c
+++ b/init/init_task.c
@@ -222,6 +222,8 @@ struct task_struct init_task __aligned(L1_CACHE_BYTES) = {
 #endif
 #ifdef CONFIG_RPAL
 	.rpal_rs = NULL,
+	.rpal_flag = 0,
+	.rpal_cd = NULL,
 #endif
 };
 EXPORT_SYMBOL(init_task);
diff --git a/kernel/fork.c b/kernel/fork.c
index 1d1c8484a8f2..01cd48eadf68 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1220,6 +1220,8 @@ static struct task_struct *dup_task_struct(struct task_struct *orig, int node)
 
 #ifdef CONFIG_RPAL
 	tsk->rpal_rs = NULL;
+	tsk->rpal_flag = 0;
+	tsk->rpal_cd = NULL;
 #endif
 
 	return tsk;
-- 
2.20.1