Re: [PATCH v2 2/2] LoongArch: BPF: Fix tailcall hierarchy

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 






On Wed, 2025-07-16 at 10:28 +0800, Hengqi Chen wrote:
> On Tue, Jul 8, 2025 at 3:19 PM Haoran Jiang <
> jianghaoran@xxxxxxxxxx
> > wrote:
> > In specific use cases combining tailcalls and BPF-to-BPF calls,
> > MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt
> > back-propagation from callee to caller. This patch fixes this
> > tailcall issue caused by abusing the tailcall in bpf2bpf
> > feature
> > on LoongArch like the way of "bpf, x64: Fix tailcall
> > hierarchy".
> > 
> > push tail_call_cnt_ptr and tail_call_cnt into the stack,
> > tail_call_cnt_ptr is passed between tailcall and bpf2bpf,
> > uses tail_call_cnt_ptr to increment tail_call_cnt.
> > 
> > Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf
> > and tailcalls")
> > Fixes: 5dc615520c4d ("LoongArch: Add BPF JIT support")
> > Signed-off-by: Haoran Jiang <
> > jianghaoran@xxxxxxxxxx
> > >
> > ---
> >  arch/loongarch/net/bpf_jit.c | 112 +++++++++++++++++++++----
> > ----------
> >  1 file changed, 68 insertions(+), 44 deletions(-)
> > 
> > diff --git a/arch/loongarch/net/bpf_jit.c
> > b/arch/loongarch/net/bpf_jit.c
> > index 5061bfc978f2..45f804b7c556 100644
> > --- a/arch/loongarch/net/bpf_jit.c
> > +++ b/arch/loongarch/net/bpf_jit.c
> > @@ -7,10 +7,9 @@
> >  #include "bpf_jit.h"
> > 
> >  #define REG_TCC                LOONGARCH_GPR_A6
> > -#define TCC_SAVED      LOONGARCH_GPR_S5
> > 
> > -#define SAVE_RA                BIT(0)
> > -#define SAVE_TCC       BIT(1)
> > +#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack)
> > (round_up(stack, 16) - 80)
> > +
> > 
> >  static const int regmap[] = {
> >         /* return value from in-kernel function, and exit value
> > for eBPF program */
> > @@ -32,32 +31,37 @@ static const int regmap[] = {
> >         [BPF_REG_AX] = LOONGARCH_GPR_T0,
> >  };
> > 
> > -static void mark_call(struct jit_ctx *ctx)
> > +static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int
> > *store_offset)
> 
> Consider adding more comments (e.g. pseudocode) for
> prepare_bpf_tail_call_cnt().
> Assembly is hard to read. (At least for me :))
> 

I will improve the comments in the next version.
> 
> >  {
> > -       ctx->flags |= SAVE_RA;
> > -}
> > +       const struct bpf_prog *prog = ctx->prog;
> > +       const bool is_main_prog = !bpf_is_subprog(prog);
> > 
> > -static void mark_tail_call(struct jit_ctx *ctx)
> > -{
> > -       ctx->flags |= SAVE_TCC;
> > -}
> > +       if (is_main_prog) {
> > +               emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
> > +               *store_offset -= sizeof(long);
> > 
> > -static bool seen_call(struct jit_ctx *ctx)
> > -{
> > -       return (ctx->flags & SAVE_RA);
> > -}
> > +               emit_tailcall_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
> 
> Why emit_tailcall_jmp() here? Shouldn't this be emit_cond_jmp()?

Agreed, emit_cond_jmp() is more appropriate here.
> 
> > 
> > -static bool seen_tail_call(struct jit_ctx *ctx)
> > -{
> > -       return (ctx->flags & SAVE_TCC);
> > -}
> > +               /* If REG_TCC < MAX_TAIL_CALL_CNT, push REG_TCC into stack */
> > +               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > 
> > -static u8 tail_call_reg(struct jit_ctx *ctx)
> > -{
> > -       if (seen_call(ctx))
> > -               return TCC_SAVED;
> > +               /* Calculate the pointer to REG_TCC in the stack and assign it to REG_TCC */
> > +               emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > +
> > +               emit_uncond_jmp(ctx, 2);
> > +
> > +               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > 
> > -       return REG_TCC;
> > +               *store_offset -= sizeof(long);
> > +               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > +
> > +       } else {
> > +               *store_offset -= sizeof(long);
> > +               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > +
> > +               *store_offset -= sizeof(long);
> > +               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
> > +       }
> >  }
> > 
> >  /*
> > @@ -80,6 +84,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
> >   *                            |           $s4           |
> >   *                            +-------------------------+
> >   *                            |           $s5           |
> > + *                            +-------------------------+
> > + *                            |          reg_tcc        |
> > + *                            +-------------------------+
> > + *                            |          reg_tcc_ptr    |
> >   *                            +-------------------------+ <--BPF_REG_FP
> >   *                            |  prog->aux->stack_depth |
> >   *                            |        (optional)       |
> > @@ -89,21 +97,24 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
> >  static void build_prologue(struct jit_ctx *ctx)
> >  {
> >         int stack_adjust = 0, store_offset, bpf_stack_adjust;
> > +       const struct bpf_prog *prog = ctx->prog;
> > +       const bool is_main_prog = !bpf_is_subprog(prog);
> > 
> >         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
> > 
> > -       /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
> > -       stack_adjust += sizeof(long) * 8;
> > +       /* To store ra, fp, s0, s1, s2, s3, s4, s5, reg_tcc and reg_tcc_ptr */
> > +       stack_adjust += sizeof(long) * 10;
> > 
> >         stack_adjust = round_up(stack_adjust, 16);
> >         stack_adjust += bpf_stack_adjust;
> > 
> >         /*
> > -        * First instruction initializes the tail call count (TCC).
> > -        * On tail call we skip this instruction, and the TCC is
> > +        * First instruction initializes the tail call count (TCC) register
> > +        * to zero. On tail call we skip this instruction, and the TCC is
> >          * passed in REG_TCC from the caller.
> >          */
> > -       emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
> > +       if (is_main_prog)
> > +               emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
> > 
> >         emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
> > 
> > @@ -131,20 +142,13 @@ static void build_prologue(struct jit_ctx *ctx)
> >         store_offset -= sizeof(long);
> >         emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
> > 
> > +       prepare_bpf_tail_call_cnt(ctx, &store_offset);
> > +
> >         emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
> > 
> >         if (bpf_stack_adjust)
> >                 emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
> > 
> > -       /*
> > -        * Program contains calls and tail calls, so REG_TCC need
> > -        * to be saved across calls.
> > -        */
> > -       if (seen_tail_call(ctx) && seen_call(ctx))
> > -               move_reg(ctx, TCC_SAVED, REG_TCC);
> > -       else
> > -               emit_insn(ctx, nop);
> > -
> >         ctx->stack_size = stack_adjust;
> >  }
> > 
> > @@ -177,6 +181,17 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
> >         load_offset -= sizeof(long);
> >         emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
> > 
> > +       /*
> > +        *  When push into the stack, follow the order of tcc then tcc_ptr.
> > +        *  When pop from the stack, first pop tcc_ptr followed by tcc
> > +        */
> > +       load_offset -= 2*sizeof(long);
> > +       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
> > +
> > +       /* pop tcc_ptr to REG_TCC */
> > +       load_offset += sizeof(long);
> > +       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
> > +
> >         emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
> > 
> >         if (!is_tail_call) {
> > @@ -211,7 +226,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
> >  static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
> >  {
> >         int off;
> > -       u8 tcc = tail_call_reg(ctx);
> > +       int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
> >         u8 a1 = LOONGARCH_GPR_A1;
> >         u8 a2 = LOONGARCH_GPR_A2;
> >         u8 t1 = LOONGARCH_GPR_T1;
> > @@ -240,11 +255,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
> >                 goto toofar;
> > 
> >         /*
> > -        * if (--TCC < 0)
> > -        *       goto out;
> > +        * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
> > +        *      goto out;
> >          */
> > -       emit_insn(ctx, addid, REG_TCC, tcc, -1);
> > -       if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
> > +       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
> > +       emit_insn(ctx, ldd, t3, REG_TCC, 0);
> > +       emit_insn(ctx, addid, t3, t3, 1);
> > +       emit_insn(ctx, std, t3, REG_TCC, 0);
> > +       emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
> > +       if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
> >                 goto toofar;
> > 
> >         /*
> > @@ -465,6 +484,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
> >         const s16 off = insn->off;
> >         const s32 imm = insn->imm;
> >         const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32;
> > +       int tcc_ptr_off;
> > 
> >         switch (code) {
> >         /* dst = src */
> > @@ -891,12 +911,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
> > 
> >         /* function call */
> >         case BPF_JMP | BPF_CALL:
> > -               mark_call(ctx);
> >                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
> >                                             &func_addr, &func_addr_fixed);
> >                 if (ret < 0)
> >                         return ret;
> > 
> > +               if (insn->src_reg == BPF_PSEUDO_CALL) {
> > +                       tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
> > +                       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
> > +               }
> > +
> > +
> >                 move_addr(ctx, t1, func_addr);
> >                 emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
> > 
> > @@ -907,7 +932,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
> > 
> >         /* tail call */
> >         case BPF_JMP | BPF_TAIL_CALL:
> > -               mark_tail_call(ctx);
> >                 if (emit_bpf_tail_call(ctx, i) < 0)
> >                         return -EINVAL;
> >                 break;
> > --
> > 2.43.0
> > 





[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux