Re: [PATCH bpf-next] bpf: Simplify bounds refinement from s32

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

 



On Thu, Jul 24, 2025 at 02:49:47PM -0700, Eduard Zingerman wrote:
> On Thu, 2025-07-24 at 19:42 +0200, Paul Chaignon wrote:
> > During the bounds refinement, we improve the precision of various ranges
> > by looking at other ranges. Among others, we improve the following in
> > this order (other things happen between 1 and 2):
> > 
> >   1. Improve u32 from s32 in __reg32_deduce_bounds.
> >   2. Improve s/u64 from u32 in __reg_deduce_mixed_bounds.
> >   3. Improve s/u64 from s32 in __reg_deduce_mixed_bounds.
> > 
> > In particular, if the s32 range forms a valid u32 range, we will use it
> > to improve the u32 range in __reg32_deduce_bounds. In
> > __reg_deduce_mixed_bounds, under the same condition, we will use the s32
> > range to improve the s/u64 ranges.
> > 
> > If at (1) we were able to learn from s32 to improve u32, we'll then be
> > able to use that in (2) to improve s/u64. Hence, as (3) happens under
> > the same precondition as (1), it won't improve s/u64 ranges further than
> > (1)+(2) did. Thus, we can get rid of (3).
> > 
> > In addition to the extensive suite of selftests for bounds refinement,
> > this patch was also tested with the Agni formal verification tool [1].
> > 
> > Link: https://github.com/bpfverif/agni [1]
> > Signed-off-by: Paul Chaignon <paul.chaignon@xxxxxxxxx>
> > ---
> 
> So, the argument appears to be as follows:
> 
> Under precondition `(u32)reg->s32_min <= (u32)reg->s32_max`
> __reg32_deduce_bounds produces:
> 
>   reg->u32_min = max_t(u32, reg->s32_min, reg->u32_min);
>   reg->u32_max = min_t(u32, reg->s32_max, reg->u32_max);
> 
> And then first part of __reg_deduce_mixed_bounds assigns:
> 
>   a. reg->umin umax= (reg->umin & ~0xffffffffULL) | max_t(u32, reg->s32_min, reg->u32_min);
>   b. reg->umax umin= (reg->umax & ~0xffffffffULL) | min_t(u32, reg->s32_max, reg->u32_max);
> 
> And then second part of __reg_deduce_mixed_bounds assigns:
> 
>   c. reg->umin umax= (reg->umin & ~0xffffffffULL) | (u32)reg->s32_min;
>   d. reg->umax umin= (reg->umax & ~0xffffffffULL) | (u32)reg->s32_max;
> 
> But assignment (c) is a noop because:
> 
>    max_t(u32, reg->s32_min, reg->u32_min) >= (u32)reg->s32_min
> 
> Hence RHS(a) >= RHS(c) and umax= does nothing.
> 
> Also assignment (d) is a noop because:
> 
>   min_t(u32, reg->s32_max, reg->u32_max) <= (u32)reg->s32_max
> 
> Hence RHS(b) <= RHS(d) and umin= does nothing.
> 
> Plus the same reasoning for the part dealing with reg->s{min,max}_value:
> 
>   e. reg->smin_value smax= (reg->smin_value & ~0xffffffffULL) | max_t(u32, reg->s32_min_value, reg->u32_min_value);
>   f. reg->smax_value smin= (reg->smax_value & ~0xffffffffULL) | min_t(u32, reg->s32_max_value, reg->u32_max_value);
> 
>     vs
> 
>   g. reg->smin_value smax= (reg->smin_value & ~0xffffffffULL) | (u32)reg->s32_min_value;
>   h. reg->smax_value smin= (reg->smax_value & ~0xffffffffULL) | (u32)reg->s32_max_value;
> 
>     RHS(e) >= RHS(g) and RHS(f) <= RHS(h), hence smax=,smin= do nothing.
> 
> This appears to be correct.
> 
> Shung-Hsi, wdyt?

Agree with the reasoning above, it looks solid.

Besides going through the reasoning, I also played with CBMC a bit to
double-check that, as far as a single run of __reg_deduce_bounds() is
concerned (and provided the register state matches certain handwavy
expectations), the change indeed still preserves the original behavior.

Reviewed-by: Shung-Hsi Yu <shung-hsi.yu@xxxxxxxx>

Simplification of bound deduction logic! \o/

#include <stdint.h>
#include <limits.h>
#include <stdbool.h>
#include <assert.h>

/* Kernel-style short names for the <stdint.h> fixed-width integers,
 * listed from narrowest to widest, signed before unsigned.
 */
typedef int8_t   s8;
typedef uint8_t  u8;
typedef int16_t  s16;
typedef uint16_t u16;
typedef int32_t  s32;
typedef uint32_t u32;
typedef int64_t  s64;
typedef uint64_t u64;

/* Kernel-style names for the <stdint.h> range limits */

/* signed ranges */
#define S8_MIN   INT8_MIN
#define S8_MAX   INT8_MAX
#define S16_MIN  INT16_MIN
#define S16_MAX  INT16_MAX
#define S32_MIN  INT32_MIN
#define S32_MAX  INT32_MAX
#define S64_MIN  INT64_MIN
#define S64_MAX  INT64_MAX

/* unsigned maxima */
#define U32_MAX  UINT32_MAX
#define U64_MAX  UINT64_MAX

/*
 * Crude stand-ins for min_t() and max_t(): both operands are cast to 'type'
 * and then compared. Each argument appears twice in the expansion, so only
 * pass expressions without side effects.
 */
#define min_t(type, x, y) ((type)(y) > (type)(x) ? (type)(x) : (type)(y))
#define max_t(type, x, y) ((type)(y) < (type)(x) ? (type)(x) : (type)(y))

// Simplified version of bpf_reg_state holding only the range bounds that the
// __reg*_deduce_bounds() functions in this file read and write. Every pair is
// treated as an inclusive [min, max] range by val_in_reg().
struct bpf_reg_state {
	s64 smin_value; /* 64-bit signed lower bound */
	s64 smax_value; /* 64-bit signed upper bound */
	u64 umin_value; /* 64-bit unsigned lower bound */
	u64 umax_value; /* 64-bit unsigned upper bound */
	s32 s32_min_value; /* signed lower bound on the low 32 bits */
	s32 s32_max_value; /* signed upper bound on the low 32 bits */
	u32 u32_min_value; /* unsigned lower bound on the low 32 bits */
	u32 u32_max_value; /* unsigned upper bound on the low 32 bits */
};

static void __reg32_deduce_bounds(struct bpf_reg_state *reg)
{
	if ((reg->umin_value >> 32) == (reg->umax_value >> 32)) {
		reg->u32_min_value = max_t(u32, reg->u32_min_value, (u32)reg->umin_value);
		reg->u32_max_value = min_t(u32, reg->u32_max_value, (u32)reg->umax_value);

		if ((s32)reg->umin_value <= (s32)reg->umax_value) {
			reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->umin_value);
			reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->umax_value);
		}
	}
	if ((reg->smin_value >> 32) == (reg->smax_value >> 32)) {
		if ((u32)reg->smin_value <= (u32)reg->smax_value) {
			reg->u32_min_value = max_t(u32, reg->u32_min_value, (u32)reg->smin_value);
			reg->u32_max_value = min_t(u32, reg->u32_max_value, (u32)reg->smax_value);
		}
		if ((s32)reg->smin_value <= (s32)reg->smax_value) {
			reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->smin_value);
			reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->smax_value);
		}
	}
	if ((u32)(reg->umin_value >> 32) + 1 == (u32)(reg->umax_value >> 32) &&
	    (s32)reg->umin_value < 0 && (s32)reg->umax_value >= 0) {
		reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->umin_value);
		reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->umax_value);
	}
	if ((u32)(reg->smin_value >> 32) + 1 == (u32)(reg->smax_value >> 32) &&
	    (s32)reg->smin_value < 0 && (s32)reg->smax_value >= 0) {
		reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->smin_value);
		reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->smax_value);
	}
	if ((s32)reg->u32_min_value <= (s32)reg->u32_max_value) {
		reg->s32_min_value = max_t(s32, reg->s32_min_value, reg->u32_min_value);
		reg->s32_max_value = min_t(s32, reg->s32_max_value, reg->u32_max_value);
	}
	if ((u32)reg->s32_min_value <= (u32)reg->s32_max_value) {
		reg->u32_min_value = max_t(u32, reg->s32_min_value, reg->u32_min_value);
		reg->u32_max_value = min_t(u32, reg->s32_max_value, reg->u32_max_value);
	}
}

static void __reg64_deduce_bounds(struct bpf_reg_state *reg)
{
	if ((s64)reg->umin_value <= (s64)reg->umax_value) {
		reg->smin_value = max_t(s64, reg->smin_value, reg->umin_value);
		reg->smax_value = min_t(s64, reg->smax_value, reg->umax_value);
	}
	if ((u64)reg->smin_value <= (u64)reg->smax_value) {
		reg->umin_value = max_t(u64, reg->smin_value, reg->umin_value);
		reg->umax_value = min_t(u64, reg->smax_value, reg->umax_value);
	}
}

/*
 * Reference ("old") variant of the mixed-width refinement: tightens the
 * 64-bit bounds from the u32 bounds, then from the s32 bounds when those are
 * also ordered as u32, and finally collapses the 64-bit bounds onto the s32
 * bounds when the value is known to be a non-negative s32. The 32-bit bounds
 * are only read.
 */
static void __reg_deduce_mixed_bounds_old(struct bpf_reg_state *reg)
{
	u64 new_umin, new_umax;
	s64 new_smin, new_smax;

	/* Candidates: keep the upper 32 bits of the current 64-bit bound and
	 * substitute the u32 bound for the lower 32 bits. A candidate is only
	 * adopted if it is tighter than the current bound.
	 */
	new_umin = (reg->umin_value & ~0xffffffffULL) | reg->u32_min_value;
	new_umax = (reg->umax_value & ~0xffffffffULL) | reg->u32_max_value;
	reg->umin_value = max_t(u64, reg->umin_value, new_umin);
	reg->umax_value = min_t(u64, reg->umax_value, new_umax);
	/* Same construction for the signed 64-bit bounds. */
	new_smin = (reg->smin_value & ~0xffffffffULL) | reg->u32_min_value;
	new_smax = (reg->smax_value & ~0xffffffffULL) | reg->u32_max_value;
	reg->smin_value = max_t(s64, reg->smin_value, new_smin);
	reg->smax_value = min_t(s64, reg->smax_value, new_smax);

	/* s32 -> u/s64 tightening: repeat the construction with the s32 bounds
	 * as the low halves, but only when they are ordered as u32. This is the
	 * block that the "new" variant drops.
	 */
	if ((u32)reg->s32_min_value <= (u32)reg->s32_max_value) {
		new_umin = (reg->umin_value & ~0xffffffffULL) | (u32)reg->s32_min_value;
		new_umax = (reg->umax_value & ~0xffffffffULL) | (u32)reg->s32_max_value;
		reg->umin_value = max_t(u64, reg->umin_value, new_umin);
		reg->umax_value = min_t(u64, reg->umax_value, new_umax);
		new_smin = (reg->smin_value & ~0xffffffffULL) | (u32)reg->s32_min_value;
		new_smax = (reg->smax_value & ~0xffffffffULL) | (u32)reg->s32_max_value;
		reg->smin_value = max_t(s64, reg->smin_value, new_smin);
		reg->smax_value = min_t(s64, reg->smax_value, new_smax);
	}
	/* s32_min is non-negative and the signed 64-bit range fits in s32:
	 * overwrite all four 64-bit bounds with the s32 bounds.
	 */
	if (reg->s32_min_value >= 0 && reg->smin_value >= S32_MIN && reg->smax_value <= S32_MAX) {
		reg->smin_value = reg->s32_min_value;
		reg->smax_value = reg->s32_max_value;
		reg->umin_value = reg->s32_min_value;
		reg->umax_value = reg->s32_max_value;
		/* var_off update with tnum_intersect() removed, was the last
		 * step, so shouldn't make a difference
		 */
	}
}

/*
 * Patched ("new") variant of the mixed-width refinement. Same as
 * __reg_deduce_mixed_bounds_old() except that the block which tightened the
 * 64-bit bounds from the s32 bounds is gone. The 32-bit bounds are only read.
 */
static void __reg_deduce_mixed_bounds_new(struct bpf_reg_state *reg)
{
	u64 new_umin, new_umax;
	s64 new_smin, new_smax;

	/* Candidates: keep the upper 32 bits of the current 64-bit bound and
	 * substitute the u32 bound for the lower 32 bits. A candidate is only
	 * adopted if it is tighter than the current bound.
	 */
	new_umin = (reg->umin_value & ~0xffffffffULL) | reg->u32_min_value;
	new_umax = (reg->umax_value & ~0xffffffffULL) | reg->u32_max_value;
	reg->umin_value = max_t(u64, reg->umin_value, new_umin);
	reg->umax_value = min_t(u64, reg->umax_value, new_umax);
	/* Same construction for the signed 64-bit bounds. */
	new_smin = (reg->smin_value & ~0xffffffffULL) | reg->u32_min_value;
	new_smax = (reg->smax_value & ~0xffffffffULL) | reg->u32_max_value;
	reg->smin_value = max_t(s64, reg->smin_value, new_smin);
	reg->smax_value = min_t(s64, reg->smax_value, new_smax);

	/* s32 -> u/s64 tightening removed */

	/* s32_min is non-negative and the signed 64-bit range fits in s32:
	 * overwrite all four 64-bit bounds with the s32 bounds.
	 */
	if (reg->s32_min_value >= 0 && reg->smin_value >= S32_MIN && reg->smax_value <= S32_MAX) {
		reg->smin_value = reg->s32_min_value;
		reg->smax_value = reg->s32_max_value;
		reg->umin_value = reg->s32_min_value;
		reg->umax_value = reg->s32_max_value;
		/* var_off update with tnum_intersect() removed, was the last
		 * step, so shouldn't make a difference
		 */
	}
}

/*
 * One pass of bounds deduction using the "old" mixed-width step: 32-bit
 * refinement, then 64-bit refinement, then __reg_deduce_mixed_bounds_old().
 * Modifies *reg in place.
 */
static void __reg_deduce_bounds_old(struct bpf_reg_state *reg)
{
	__reg32_deduce_bounds(reg);
	__reg64_deduce_bounds(reg);
	__reg_deduce_mixed_bounds_old(reg);
}

/*
 * One pass of bounds deduction using the "new" mixed-width step: 32-bit
 * refinement, then 64-bit refinement, then __reg_deduce_mixed_bounds_new().
 * Modifies *reg in place.
 */
static void __reg_deduce_bounds_new(struct bpf_reg_state *reg)
{
	__reg32_deduce_bounds(reg);
	__reg64_deduce_bounds(reg);
	__reg_deduce_mixed_bounds_new(reg);
}

/*
 * Prototypes for the nondeterministic value sources used below. They have no
 * definition in this file; the verification tool is expected to supply their
 * values. Declaring them matters: calling an undeclared function is invalid
 * since C99, and a tool that still accepts it types the call as returning
 * int, so the 64-bit fields would only ever receive sign-extended 32-bit
 * values.
 * NOTE(review): CBMC is understood to fall back to an int return type for
 * undeclared functions - confirm against the CBMC version in use.
 */
long long nondet_long_long_input(void);
unsigned long long nondet_unsigned_long_long_input(void);
int nondet_int_input(void);
unsigned int nondet_unsigned_int_input(void);

/*
 * helper function to initialize 'struct bpf_reg_state'
 *
 * Returns a register state in which each of the eight bounds is filled from
 * the nondet source of matching width and signedness. No relationship
 * between the fields is established here; callers constrain the result
 * afterwards.
 */
static struct bpf_reg_state __bpf_reg_state_input(void)
{
	struct bpf_reg_state reg;

	reg.smin_value = nondet_long_long_input();
	reg.smax_value = nondet_long_long_input();
	reg.umin_value = nondet_unsigned_long_long_input();
	reg.umax_value = nondet_unsigned_long_long_input();
	reg.s32_min_value = nondet_int_input();
	reg.s32_max_value = nondet_int_input();
	reg.u32_min_value = nondet_unsigned_int_input();
	reg.u32_max_value = nondet_unsigned_int_input();
	return reg;
}

/* helper function to ensure 'struct bpf_reg_state' is in a proper state */
static bool valid_bpf_reg_state(struct bpf_reg_state *reg)
{
	bool ret = true;
	/* Ensure maximum >= minimum for all ranges */
	ret &= reg->umin_value <= reg->umax_value;
	ret &= reg->smin_value <= reg->smax_value;
	ret &= reg->u32_min_value <= reg->u32_max_value;
	ret &= reg->s32_min_value <= reg->s32_max_value;
	/* Ensure 64-bit bounds are consistent with 32-bit bounds */
	ret &= reg->umin_value <= (u64)reg->u32_max_value;
	ret &= reg->umax_value >= (u64)reg->u32_min_value;
	ret &= (s64)reg->smin_value <= (s64)reg->s32_max_value;
	ret &= (s64)reg->smax_value >= (s64)reg->s32_min_value;
	return ret;
}

/*
 * helper function to check whether 'struct bpf_reg_state' contains certain value
 *
 * 'val' is contained when it lies inside all four inclusive ranges: the
 * 64-bit unsigned and signed ranges for the full value, and the 32-bit
 * unsigned and signed ranges for its low 32 bits.
 */
static bool val_in_reg(struct bpf_reg_state *reg, u64 val)
{
	const s64 val_s64 = (s64)val;
	const u32 val_u32 = (u32)val;
	const s32 val_s32 = (s32)val;

	return reg->umin_value <= val && val <= reg->umax_value &&
	       reg->smin_value <= val_s64 && val_s64 <= reg->smax_value &&
	       reg->u32_min_value <= val_u32 && val_u32 <= reg->u32_max_value &&
	       reg->s32_min_value <= val_s32 && val_s32 <= reg->s32_max_value;
}

void main(void)
{
	/* ------------ Assumptions and Setup ------------ */

	/* Input data structure that represents current knowledge of the possible
	 * values in a register, as well as some possible value 'x', which could be
	 * any value that is in the register right now.
	 */
	struct bpf_reg_state reg = __bpf_reg_state_input();
	u64 x = nondet_unsigned_long_long_input();
	__CPROVER_assume(valid_bpf_reg_state(&reg));
	__CPROVER_assume(val_in_reg(&reg, x));

	/* ------------- Operation to Check -------------- */
	/* Data structure to store the new output */
	struct bpf_reg_state new_reg;
	/* Clone the register state since __reg_deduce_bounds() modifies it */
	new_reg = reg;

	__reg_deduce_bounds_old(&reg);
	__reg_deduce_bounds_new(&new_reg);

	/* -------------- Property Checking -------------- */
	assert(new_reg == reg);
}

[Index of Archives]     [Linux Samsung SoC]     [Linux Rockchip SoC]     [Linux Actions SoC]     [Linux for Synopsys ARC Processors]     [Linux NFS]     [Linux NILFS]     [Linux USB Devel]     [Video for Linux]     [Linux Audio Users]     [Yosemite News]     [Linux Kernel]     [Linux SCSI]


  Powered by Linux