Require that iter->batch always contains a full bucket snapshot. This
invariant is important to avoid skipping or repeating sockets during
iteration when combined with the next few patches. Before, there were
two cases where a call to bpf_iter_tcp_batch may only capture part of
a bucket:

1. When bpf_iter_tcp_realloc_batch() returns -ENOMEM.
2. When more sockets are added to the bucket while calling
   bpf_iter_tcp_realloc_batch(), making the updated batch size
   insufficient.

In cases where the batch size only covers part of a bucket, it is
possible to forget which sockets were already visited, especially if we
have to process a bucket in more than two batches. This forces us to
choose between repeating or skipping sockets, so don't allow this:

1. Stop iteration and propagate -ENOMEM up to userspace if reallocation
   fails instead of continuing with a partial batch.
2. Try bpf_iter_tcp_realloc_batch() with GFP_USER just as before, but if
   we still aren't able to capture the full bucket, call
   bpf_iter_tcp_realloc_batch() again while holding the bucket lock to
   guarantee the bucket does not change. On the second attempt use
   GFP_NOWAIT since we hold onto the spin lock.

I did some manual testing to exercise the code paths where GFP_NOWAIT is
used and where ERR_PTR(err) is returned. I used the realloc test cases
included later in this series to trigger a scenario where a realloc
happens inside bpf_iter_tcp_batch and made a small code tweak to force
the first realloc attempt to allocate a too-small batch, thus requiring
another attempt with GFP_NOWAIT. Some printks showed both reallocs with
the tests passing:

May 09 18:18:55 crow kernel: resize batch TCP_SEQ_STATE_LISTENING
May 09 18:18:55 crow kernel: again GFP_USER
May 09 18:18:55 crow kernel: resize batch TCP_SEQ_STATE_LISTENING
May 09 18:18:55 crow kernel: again GFP_NOWAIT
May 09 18:18:57 crow kernel: resize batch TCP_SEQ_STATE_ESTABLISHED
May 09 18:18:57 crow kernel: again GFP_USER
May 09 18:18:57 crow kernel: resize batch TCP_SEQ_STATE_ESTABLISHED
May 09 18:18:57 crow kernel: again GFP_NOWAIT

With this setup, I also forced each of the bpf_iter_tcp_realloc_batch
calls to return -ENOMEM to ensure that iteration ends and that the
read() in userspace fails.
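To make the two-attempt strategy above easier to follow, here is a
compact userspace model of it. This is only an illustrative sketch: the
bucket/iter structs and the realloc_batch()/fill_batch() helpers are
invented for the example, and the bucket lock is only simulated in
comments; the real kernel changes are in the diff below.

  #include <errno.h>
  #include <stdio.h>
  #include <stdlib.h>

  struct bucket {
          int nr;         /* sockets currently hashed into the bucket */
  };

  struct iter {
          int *batch;     /* snapshot of the bucket */
          int max;        /* batch capacity */
          int end;        /* sockets captured so far */
  };

  /* Grow the batch; a failure here maps to the -ENOMEM path above. */
  static int realloc_batch(struct iter *it, int new_sz)
  {
          int *nb = realloc(it->batch, new_sz * sizeof(*nb));

          if (!nb)
                  return -ENOMEM;
          it->batch = nb;
          it->max = new_sz;
          return 0;
  }

  /* Copy as much of the bucket as fits; return how many sockets it holds. */
  static int fill_batch(struct iter *it, const struct bucket *b)
  {
          int i;

          it->end = 0;
          for (i = 0; i < b->nr && it->end < it->max; i++)
                  it->batch[it->end++] = i;
          return b->nr;
  }

  /* Batch one bucket, insisting on a full snapshot or an error. */
  static int batch_bucket(struct iter *it, struct bucket *b)
  {
          int expected, err;

          expected = fill_batch(it, b);   /* first pass, bucket may change */
          if (it->end == expected)
                  return 0;

          /* First resize without the "lock": leave headroom
           * (expected * 3 / 2) in case the bucket grows before the retry.
           */
          err = realloc_batch(it, expected * 3 / 2);
          if (err)
                  return err;

          b->nr += 4;     /* simulate the bucket growing in the meantime */

          /* Second pass: in the kernel this runs with the bucket lock held,
           * so the bucket cannot change again and the GFP_NOWAIT resize
           * below is sized exactly right.
           */
          expected = fill_batch(it, b);
          if (it->end != expected) {
                  err = realloc_batch(it, expected);
                  if (err)
                          return err;
                  expected = fill_batch(it, b);
          }
          return it->end == expected ? 0 : -EAGAIN;
  }

  int main(void)
  {
          struct bucket b = { .nr = 5 };
          struct iter it = { .batch = malloc(2 * sizeof(int)), .max = 2 };
          int err;

          if (!it.batch)
                  return 1;
          err = batch_bucket(&it, &b);
          printf("err=%d captured=%d of %d\n", err, it.end, b.nr);
          free(it.batch);
          return 0;
  }

Compiled and run (e.g. "gcc -O2 -o batch_model batch_model.c"; the file
name is arbitrary), this prints "err=0 captured=9 of 9": the batch ends
up covering the whole (grown) bucket, or an error is returned, which is
the invariant the patch enforces.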
Signed-off-by: Jordan Rife <jordan@xxxxxxxx>
Reviewed-by: Kuniyuki Iwashima <kuniyu@xxxxxxxxxx>
---
 net/ipv4/tcp_ipv4.c | 96 ++++++++++++++++++++++++++++++++-------------
 1 file changed, 68 insertions(+), 28 deletions(-)

diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 2e40af6aff37..69c976a07434 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -3057,7 +3057,10 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
 	if (!new_batch)
 		return -ENOMEM;
 
-	bpf_iter_tcp_put_batch(iter);
+	if (flags != GFP_NOWAIT)
+		bpf_iter_tcp_put_batch(iter);
+
+	memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk);
 	kvfree(iter->batch);
 	iter->batch = new_batch;
 	iter->max_sk = new_batch_sz;
@@ -3066,69 +3069,85 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
 }
 
 static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
-						 struct sock *start_sk)
+						 struct sock **start_sk)
 {
-	struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
 	struct bpf_tcp_iter_state *iter = seq->private;
-	struct tcp_iter_state *st = &iter->state;
 	struct hlist_nulls_node *node;
 	unsigned int expected = 1;
 	struct sock *sk;
 
-	sock_hold(start_sk);
-	iter->batch[iter->end_sk++] = start_sk;
+	sock_hold(*start_sk);
+	iter->batch[iter->end_sk++] = *start_sk;
 
-	sk = sk_nulls_next(start_sk);
+	sk = sk_nulls_next(*start_sk);
+	*start_sk = NULL;
 	sk_nulls_for_each_from(sk, node) {
 		if (seq_sk_match(seq, sk)) {
 			if (iter->end_sk < iter->max_sk) {
 				sock_hold(sk);
 				iter->batch[iter->end_sk++] = sk;
+			} else if (!*start_sk) {
+				/* Remember where we left off. */
+				*start_sk = sk;
 			}
 			expected++;
 		}
 	}
-	spin_unlock(&hinfo->lhash2[st->bucket].lock);
 
 	return expected;
 }
 
 static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
-						   struct sock *start_sk)
+						   struct sock **start_sk)
 {
-	struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
 	struct bpf_tcp_iter_state *iter = seq->private;
-	struct tcp_iter_state *st = &iter->state;
 	struct hlist_nulls_node *node;
 	unsigned int expected = 1;
 	struct sock *sk;
 
-	sock_hold(start_sk);
-	iter->batch[iter->end_sk++] = start_sk;
+	sock_hold(*start_sk);
+	iter->batch[iter->end_sk++] = *start_sk;
 
-	sk = sk_nulls_next(start_sk);
+	sk = sk_nulls_next(*start_sk);
+	*start_sk = NULL;
 	sk_nulls_for_each_from(sk, node) {
 		if (seq_sk_match(seq, sk)) {
 			if (iter->end_sk < iter->max_sk) {
 				sock_hold(sk);
 				iter->batch[iter->end_sk++] = sk;
+			} else if (!*start_sk) {
+				/* Remember where we left off. */
+				*start_sk = sk;
 			}
 			expected++;
 		}
 	}
-	spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
 
 	return expected;
 }
 
+static void bpf_iter_tcp_unlock_bucket(struct seq_file *seq)
+{
+	struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
+	struct bpf_tcp_iter_state *iter = seq->private;
+	struct tcp_iter_state *st = &iter->state;
+
+	if (st->state == TCP_SEQ_STATE_LISTENING)
+		spin_unlock(&hinfo->lhash2[st->bucket].lock);
+	else
+		spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
+}
+
 static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
 {
 	struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
 	struct bpf_tcp_iter_state *iter = seq->private;
 	struct tcp_iter_state *st = &iter->state;
+	int prev_bucket, prev_state;
 	unsigned int expected;
-	bool resized = false;
+	int resizes = 0;
 	struct sock *sk;
+	int err;
 
 	/* The st->bucket is done.  Directly advance to the next
 	 * bucket instead of having the tcp_seek_last_pos() to skip
@@ -3149,29 +3168,50 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
 	/* Get a new batch */
 	iter->cur_sk = 0;
 	iter->end_sk = 0;
-	iter->st_bucket_done = false;
+	iter->st_bucket_done = true;
 
+	prev_bucket = st->bucket;
+	prev_state = st->state;
 	sk = tcp_seek_last_pos(seq);
 	if (!sk)
 		return NULL; /* Done */
+	if (st->bucket != prev_bucket || st->state != prev_state)
+		resizes = 0;
+	expected = 0;
 
+fill_batch:
 	if (st->state == TCP_SEQ_STATE_LISTENING)
-		expected = bpf_iter_tcp_listening_batch(seq, sk);
+		expected += bpf_iter_tcp_listening_batch(seq, &sk);
 	else
-		expected = bpf_iter_tcp_established_batch(seq, sk);
+		expected += bpf_iter_tcp_established_batch(seq, &sk);
 
-	if (iter->end_sk == expected) {
-		iter->st_bucket_done = true;
-		return sk;
-	}
+	if (unlikely(resizes <= 1 && iter->end_sk != expected)) {
+		resizes++;
+
+		if (resizes == 1) {
+			bpf_iter_tcp_unlock_bucket(seq);
 
-	if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2,
-						    GFP_USER)) {
-		resized = true;
-		goto again;
+			err = bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2,
+							 GFP_USER);
+			if (err)
+				return ERR_PTR(err);
+			goto again;
+		}
+
+		err = bpf_iter_tcp_realloc_batch(iter, expected, GFP_NOWAIT);
+		if (err) {
+			bpf_iter_tcp_unlock_bucket(seq);
+			return ERR_PTR(err);
+		}
+
+		expected = iter->end_sk;
+		goto fill_batch;
 	}
 
-	return sk;
+	bpf_iter_tcp_unlock_bucket(seq);
+
+	WARN_ON_ONCE(iter->end_sk != expected);
+	return iter->batch[0];
 }
 
 static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos)
-- 
2.43.0