[PATCH v5 2/2] LoongArch: BPF: Fix tailcall hierarchy

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



In specific use cases combining tailcalls and BPF-to-BPF calls,
MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt
back-propagation from callee to caller. This patch fixes this
tailcall issue, which is caused by abusing tailcalls together with
the bpf2bpf feature on LoongArch, in the same way as
"bpf, x64: Fix tailcall hierarchy".

Push tail_call_cnt_ptr and tail_call_cnt onto the stack.
tail_call_cnt_ptr is passed between tailcalls and bpf2bpf calls,
and is used to increment tail_call_cnt.

Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls")
Fixes: 5dc615520c4d ("LoongArch: Add BPF JIT support")

Reviewed-by: Geliang Tang <geliang@xxxxxxxxxx>
Reviewed-by: Hengqi Chen <hengqi.chen@xxxxxxxxx>
Signed-off-by: Haoran Jiang <jianghaoran@xxxxxxxxxx>
---
 arch/loongarch/net/bpf_jit.c | 150 ++++++++++++++++++++++++-----------
 1 file changed, 104 insertions(+), 46 deletions(-)

diff --git a/arch/loongarch/net/bpf_jit.c b/arch/loongarch/net/bpf_jit.c
index 3f855d87eee3..e0274ae84651 100644
--- a/arch/loongarch/net/bpf_jit.c
+++ b/arch/loongarch/net/bpf_jit.c
@@ -17,10 +17,7 @@
 #define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4)
 
 #define REG_TCC		LOONGARCH_GPR_A6
-#define TCC_SAVED	LOONGARCH_GPR_S5
-
-#define SAVE_RA		BIT(0)
-#define SAVE_TCC	BIT(1)
+#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
 
 static const int regmap[] = {
 	/* return value from in-kernel function, and exit value for eBPF program */
@@ -42,32 +39,55 @@ static const int regmap[] = {
 	[BPF_REG_AX] = LOONGARCH_GPR_T0,
 };
 
-static void mark_call(struct jit_ctx *ctx)
+static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
 {
-	ctx->flags |= SAVE_RA;
-}
+	const struct bpf_prog *prog = ctx->prog;
+	const bool is_main_prog = !bpf_is_subprog(prog);
 
-static void mark_tail_call(struct jit_ctx *ctx)
-{
-	ctx->flags |= SAVE_TCC;
-}
+	if (is_main_prog) {
+		/*
+		 * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT
+		 * if (REG_TCC > T3 )
+		 *	std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+		 * else
+		 *	std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+		 *	REG_TCC = LOONGARCH_GPR_SP + store_offset
+		 *
+		 * std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+		 *
+		 * The purpose of this code is to first push the TCC into stack,
+		 * and then push the address of TCC into stack.
+		 * In cases where bpf2bpf and tailcall are used in combination,
+		 * the value in REG_TCC may be a count or an address,
+		 * these two cases need to be judged and handled separately.
+		 */
+		emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+		*store_offset -= sizeof(long);
 
-static bool seen_call(struct jit_ctx *ctx)
-{
-	return (ctx->flags & SAVE_RA);
-}
+		emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
 
-static bool seen_tail_call(struct jit_ctx *ctx)
-{
-	return (ctx->flags & SAVE_TCC);
-}
+		/* If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count,
+		 * push TCC into stack
+		 */
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
 
-static u8 tail_call_reg(struct jit_ctx *ctx)
-{
-	if (seen_call(ctx))
-		return TCC_SAVED;
+		/* Push the address of TCC into the REG_TCC */
+		emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+
+		emit_uncond_jmp(ctx, 2);
+
+		/* If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address,
+		 * push TCC_ptr into stack
+		 */
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+	} else {
+		*store_offset -= sizeof(long);
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+	}
 
-	return REG_TCC;
+	/*push TCC_ptr into stack*/
+	*store_offset -= sizeof(long);
+	emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
 }
 
 /*
@@ -90,6 +110,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
  *                            |           $s4           |
  *                            +-------------------------+
  *                            |           $s5           |
+ *                            +-------------------------+
+ *                            |           tcc           |
+ *                            +-------------------------+
+ *                            |           tcc_ptr       |
  *                            +-------------------------+ <--BPF_REG_FP
  *                            |  prog->aux->stack_depth |
  *                            |        (optional)       |
@@ -100,12 +124,17 @@ static void build_prologue(struct jit_ctx *ctx)
 {
 	int i;
 	int stack_adjust = 0, store_offset, bpf_stack_adjust;
+	const struct bpf_prog *prog = ctx->prog;
+	const bool is_main_prog = !bpf_is_subprog(prog);
 
 	bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
 
-	/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
+	/* To store ra, fp, s0, s1, s2, s3, s4, s5 */
 	stack_adjust += sizeof(long) * 8;
 
+	/* To store tcc and tcc_ptr */
+	stack_adjust += sizeof(long) * 2;
+
 	stack_adjust = round_up(stack_adjust, 16);
 	stack_adjust += bpf_stack_adjust;
 
@@ -114,11 +143,12 @@ static void build_prologue(struct jit_ctx *ctx)
 		emit_insn(ctx, nop);
 
 	/*
-	 * First instruction initializes the tail call count (TCC).
-	 * On tail call we skip this instruction, and the TCC is
+	 * First instruction initializes the tail call count (TCC) register
+	 * to zero. On tail call we skip this instruction, and the TCC is
 	 * passed in REG_TCC from the caller.
 	 */
-	emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+	if (is_main_prog)
+		emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
 
 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
 
@@ -146,20 +176,13 @@ static void build_prologue(struct jit_ctx *ctx)
 	store_offset -= sizeof(long);
 	emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
 
+	prepare_bpf_tail_call_cnt(ctx, &store_offset);
+
 	emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
 
 	if (bpf_stack_adjust)
 		emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
 
-	/*
-	 * Program contains calls and tail calls, so REG_TCC need
-	 * to be saved across calls.
-	 */
-	if (seen_tail_call(ctx) && seen_call(ctx))
-		move_reg(ctx, TCC_SAVED, REG_TCC);
-	else
-		emit_insn(ctx, nop);
-
 	ctx->stack_size = stack_adjust;
 }
 
@@ -192,6 +215,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
 	load_offset -= sizeof(long);
 	emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
 
+	/*
+	 *  When push into the stack, follow the order of tcc then tcc_ptr.
+	 *  When pop from the stack, first pop tcc_ptr followed by tcc
+	 */
+	load_offset -= 2*sizeof(long);
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
+	load_offset += sizeof(long);
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
 	emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
 
 	if (!is_tail_call) {
@@ -204,7 +237,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
 		 * Call the next bpf prog and skip the first instruction
 		 * of TCC initialization.
 		 */
-		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1);
+		emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6);
 	}
 }
 
@@ -226,7 +259,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
 static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
 {
 	int off;
-	u8 tcc = tail_call_reg(ctx);
+	int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
 	u8 a1 = LOONGARCH_GPR_A1;
 	u8 a2 = LOONGARCH_GPR_A2;
 	u8 t1 = LOONGARCH_GPR_T1;
@@ -255,11 +288,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
 		goto toofar;
 
 	/*
-	 * if (--TCC < 0)
-	 *	 goto out;
+	 * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
+	 *      goto out;
 	 */
-	emit_insn(ctx, addid, REG_TCC, tcc, -1);
-	if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
+	emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+	emit_insn(ctx, ldd, t3, REG_TCC, 0);
+	emit_insn(ctx, addid, t3, t3, 1);
+	emit_insn(ctx, std, t3, REG_TCC, 0);
+	emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+	if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
 		goto toofar;
 
 	/*
@@ -480,6 +517,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
 	const s16 off = insn->off;
 	const s32 imm = insn->imm;
 	const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
+	int tcc_ptr_off;
 
 	switch (code) {
 	/* dst = src */
@@ -906,12 +944,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
 
 	/* function call */
 	case BPF_JMP | BPF_CALL:
-		mark_call(ctx);
 		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
 					    &func_addr, &func_addr_fixed);
 		if (ret < 0)
 			return ret;
 
+		if (insn->src_reg == BPF_PSEUDO_CALL) {
+			tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
+			emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+		}
+
 		move_addr(ctx, t1, func_addr);
 		emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
 
@@ -922,7 +964,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
 
 	/* tail call */
 	case BPF_JMP | BPF_TAIL_CALL:
-		mark_tail_call(ctx);
 		if (emit_bpf_tail_call(ctx, i) < 0)
 			return -EINVAL;
 		break;
@@ -1597,7 +1638,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 {
 	int i;
 	int stack_size = 0, nargs = 0;
-	int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off;
+	int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off;
 	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
 	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
 	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
@@ -1633,6 +1674,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 	 *
 	 * FP - sreg_off    [ callee saved reg  ]
 	 *
+	 * FP - tcc_ptr_off [ tail_call_cnt_ptr ]
 	 */
 
 	if (m->nr_args > LOONGARCH_MAX_REG_ARGS)
@@ -1675,6 +1717,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 	stack_size += 8;
 	sreg_off = stack_size;
 
+	/* room of trampoline frame to store tail_call_cnt_ptr */
+	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
+		stack_size += 8;
+		tcc_ptr_off = stack_size;
+	}
+
 	stack_size = round_up(stack_size, 16);
 
 	if (!is_struct_ops) {
@@ -1703,6 +1751,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 		emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size);
 	}
 
+	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+		emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
 	/* callee saved register S1 to pass start time */
 	emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
 
@@ -1750,6 +1801,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 
 	if (flags & BPF_TRAMP_F_CALL_ORIG) {
 		restore_args(ctx, m->nr_args, args_off);
+
+		if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+			emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
 		ret = emit_call(ctx, (const u64)orig_call);
 		if (ret)
 			goto out;
@@ -1791,6 +1846,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 
 	emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
 
+	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+		emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
 	if (!is_struct_ops) {
 		/* trampoline called from function entry */
 		emit_insn(ctx, ldd, LOONGARCH_GPR_T0, LOONGARCH_GPR_SP, stack_size - 8);
-- 
2.43.0





[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux