When a reuseport socket group is using a BPF filter to distribute
the packets among the sockets, we don't need to compute any hash
value, but the current reuseport_select_sock() requires the
caller to compute such a hash in advance.

This patch reworks reuseport_select_sock() to compute the hash value
only when needed, i.e. when no BPF filter is attached or the filter
fails to select a socket. Since the hash functions take different
argument types - IPv4 vs IPv6 addresses - reuseport_select_sock() is
now a macro, to avoid over-complicating the interface.

Additionally, the sk_reuseport test is moved inside
reuseport_select_sock(), to avoid some code duplication.

Overall this gives a small but measurable performance improvement
under UDP flood when using SO_REUSEPORT + BPF.

Signed-off-by: Paolo Abeni <pab...@redhat.com>
---
 include/net/sock_reuseport.h | 32 ++++++++++++++++++++++++++++----
 net/core/sock_reuseport.c    | 34 +++++++++++++++-------------------
 net/ipv4/inet_hashtables.c   | 28 ++++++++++------------------
 net/ipv4/udp.c               | 30 ++++++++++++------------------
 net/ipv6/inet6_hashtables.c  | 28 ++++++++++------------------
 net/ipv6/udp.c               | 31 ++++++++++++-------------------
 6 files changed, 87 insertions(+), 96 deletions(-)

diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
index 0054b3a9b923..e7d71d22dca7 100644
--- a/include/net/sock_reuseport.h
+++ b/include/net/sock_reuseport.h
@@ -16,13 +16,52 @@ struct sock_reuseport {
        struct sock             *socks[0];      /* array of sock pointers */
 };
 
+/* Selection state filled in by __reuseport_get_info() */
+struct reuseport_info {
+       struct sock_reuseport *reuse;   /* group of the looked-up socket */
+       struct sock *sk;                /* socket selected by BPF, or NULL */
+       u16 socks;                      /* reuse->num_socks snapshot */
+};
+
 extern int reuseport_alloc(struct sock *sk);
 extern int reuseport_add_sock(struct sock *sk, struct sock *sk2);
 extern void reuseport_detach_sock(struct sock *sk);
-extern struct sock *reuseport_select_sock(struct sock *sk,
-                                         u32 hash,
-                                         struct sk_buff *skb,
-                                         int hdr_len);
+extern bool __reuseport_get_info(struct sock *sk, struct sk_buff *skb,
+                                int hdr_len, struct reuseport_info *info);
+
+/* Hash fallback: map @hash onto one of the @info->socks group members */
+static inline struct sock *__reuseport_select_sock(struct reuseport_info *info,
+                                                  u32 hash)
+{
+       return info->reuse->socks[reciprocal_scale(hash, info->socks)];
+}
+
+/*
+ * reuseport_select_sock - Select a socket from an SO_REUSEPORT group.
+ *
+ * Evaluates to the socket chosen by the group's BPF filter or, when no
+ * filter selects a socket, to the one picked by __reuseport_select_sock()
+ * using the hash fn(net, daddr, dport, saddr, sport); the hash is computed
+ * only in the latter case. Evaluates to NULL if @sk does not have
+ * sk_reuseport set or __reuseport_get_info() returns false. @hlen is passed
+ * as @hdr_len to __reuseport_get_info(). Takes the RCU read lock internally;
+ * every argument is evaluated at most once.
+ */
+#define reuseport_select_sock(sk, skb, net, hlen, fn, saddr, sport, daddr, dport) \
+({                                                                            \
+       struct sock *__rsk = (sk), *__rsel = NULL;                            \
+       struct reuseport_info __ri;                                           \
+                                                                             \
+       if (__rsk->sk_reuseport) {                                            \
+               rcu_read_lock();                                              \
+               if (__reuseport_get_info(__rsk, skb, hlen, &__ri))            \
+                       __rsel = __ri.sk ?: __reuseport_select_sock(&__ri,    \
+                                       fn(net, daddr, dport, saddr, sport)); \
+               rcu_read_unlock();                                            \
+       }                                                                     \
+       __rsel;                                                               \
+})
+
 extern struct bpf_prog *reuseport_attach_prog(struct sock *sk,
                                              struct bpf_prog *prog);
 
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index c5bb52bc73a1..8d66e66239a2 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -201,31 +201,30 @@ static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks,
 }
 
 /**
- *  reuseport_select_sock - Select a socket from an SO_REUSEPORT group.
+ *  __reuseport_get_info - Retrieve information for reuseport socket selection
  *  @sk: First socket in the group.
- *  @hash: When no BPF filter is available, use this hash to select.
  *  @skb: skb to run through BPF filter.
  *  @hdr_len: BPF filter expects skb data pointer at payload data.  If
  *    the skb does not yet point at the payload, this parameter represents
  *    how far the pointer needs to advance to reach the payload.
- *  Returns a socket that should receive the packet (or NULL on error).
+ *  @info: reuseport information, filled only if the return value is true
+ *  Returns true if @sk is a reuseport socket and fills @info accordingly.
+ *  If @info->sk is NULL, the caller must select the reuseport socket by
+ *  calling __reuseport_select_sock(). The caller must hold the RCU read lock.
  */
-struct sock *reuseport_select_sock(struct sock *sk,
-                                  u32 hash,
-                                  struct sk_buff *skb,
-                                  int hdr_len)
+bool __reuseport_get_info(struct sock *sk, struct sk_buff *skb, int hdr_len,
+                         struct reuseport_info *info)
 {
        struct sock_reuseport *reuse;
        struct bpf_prog *prog;
-       struct sock *sk2 = NULL;
        u16 socks;
 
-       rcu_read_lock();
+       info->sk = NULL;
        reuse = rcu_dereference(sk->sk_reuseport_cb);
 
        /* if memory allocation failed or add call is not yet complete */
        if (!reuse)
-               goto out;
+               return false;
 
        prog = rcu_dereference(reuse->prog);
        socks = READ_ONCE(reuse->num_socks);
@@ -234,18 +233,15 @@ struct sock *reuseport_select_sock(struct sock *sk,
                smp_rmb();
 
                if (prog && skb)
-                       sk2 = run_bpf(reuse, socks, prog, skb, hdr_len);
+                       info->sk = run_bpf(reuse, socks, prog, skb, hdr_len);
 
-               /* no bpf or invalid bpf result: fall back to hash usage */
-               if (!sk2)
-                       sk2 = reuse->socks[reciprocal_scale(hash, socks)];
+               info->reuse = reuse;
+               info->socks = socks;
+               return true;
        }
-
-out:
-       rcu_read_unlock();
-       return sk2;
+       return false;
 }
-EXPORT_SYMBOL(reuseport_select_sock);
+EXPORT_SYMBOL(__reuseport_get_info);
 
 struct bpf_prog *
 reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index f6f58108b4c5..eed48aab05f5 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -273,21 +273,17 @@ static struct sock *inet_lhash2_lookup(struct net *net,
        struct inet_connection_sock *icsk;
        struct sock *sk, *result = NULL;
        int score, hiscore = 0;
-       u32 phash = 0;
 
        inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
                sk = (struct sock *)icsk;
                score = compute_score(sk, net, hnum, daddr,
                                      dif, sdif, exact_dif);
                if (score > hiscore) {
-                       if (sk->sk_reuseport) {
-                               phash = inet_ehashfn(net, daddr, hnum,
-                                                    saddr, sport);
-                               result = reuseport_select_sock(sk, phash,
-                                                              skb, doff);
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net, doff,
+                                                      inet_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        hiscore = score;
                }
@@ -310,7 +306,6 @@ struct sock *__inet_lookup_listener(struct net *net,
        struct sock *sk, *result = NULL;
        int score, hiscore = 0;
        unsigned int hash2;
-       u32 phash = 0;
 
        if (ilb->count <= 10 || !hashinfo->lhash2)
                goto port_lookup;
@@ -346,14 +341,11 @@ struct sock *__inet_lookup_listener(struct net *net,
                score = compute_score(sk, net, hnum, daddr,
                                      dif, sdif, exact_dif);
                if (score > hiscore) {
-                       if (sk->sk_reuseport) {
-                               phash = inet_ehashfn(net, daddr, hnum,
-                                                    saddr, sport);
-                               result = reuseport_select_sock(sk, phash,
-                                                              skb, doff);
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net, doff,
+                                                      inet_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        hiscore = score;
                }
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index e9c0d1e1772e..8072755bb5fc 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -440,7 +440,6 @@ static struct sock *udp4_lib_lookup2(struct net *net,
 {
        struct sock *sk, *result;
        int score, badness;
-       u32 hash = 0;
 
        result = NULL;
        badness = 0;
@@ -448,14 +447,12 @@ static struct sock *udp4_lib_lookup2(struct net *net,
                score = compute_score(sk, net, saddr, sport,
                                      daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
-                       if (sk->sk_reuseport) {
-                               hash = udp_ehashfn(net, daddr, hnum,
-                                                  saddr, sport);
-                               result = reuseport_select_sock(sk, hash, skb,
-                                                       sizeof(struct udphdr));
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net,
+                                                      sizeof(struct udphdr),
+                                                      udp_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        badness = score;
                        result = sk;
                }
@@ -476,7 +473,6 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
        struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
        bool exact_dif = udp_lib_exact_dif_match(net, skb);
        int score, badness;
-       u32 hash = 0;
 
        if (hslot->count > 10) {
                hash2 = ipv4_portaddr_hash(net, daddr, hnum);
@@ -513,14 +509,12 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                score = compute_score(sk, net, saddr, sport,
                                      daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
-                       if (sk->sk_reuseport) {
-                               hash = udp_ehashfn(net, daddr, hnum,
-                                                  saddr, sport);
-                               result = reuseport_select_sock(sk, hash, skb,
-                                                       sizeof(struct udphdr));
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net,
+                                                      sizeof(struct udphdr),
+                                                      udp_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        badness = score;
                }
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index 2febe26de6a1..f6167e647672 100644
--- a/net/ipv6/inet6_hashtables.c
+++ b/net/ipv6/inet6_hashtables.c
@@ -136,21 +136,17 @@ static struct sock *inet6_lhash2_lookup(struct net *net,
        struct inet_connection_sock *icsk;
        struct sock *sk, *result = NULL;
        int score, hiscore = 0;
-       u32 phash = 0;
 
        inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
                sk = (struct sock *)icsk;
                score = compute_score(sk, net, hnum, daddr, dif, sdif,
                                      exact_dif);
                if (score > hiscore) {
-                       if (sk->sk_reuseport) {
-                               phash = inet6_ehashfn(net, daddr, hnum,
-                                                     saddr, sport);
-                               result = reuseport_select_sock(sk, phash,
-                                                              skb, doff);
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net, doff,
+                                                      inet6_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        hiscore = score;
                }
@@ -173,7 +169,6 @@ struct sock *inet6_lookup_listener(struct net *net,
        struct sock *sk, *result = NULL;
        int score, hiscore = 0;
        unsigned int hash2;
-       u32 phash = 0;
 
        if (ilb->count <= 10 || !hashinfo->lhash2)
                goto port_lookup;
@@ -208,14 +203,11 @@ struct sock *inet6_lookup_listener(struct net *net,
        sk_for_each(sk, &ilb->head) {
                score = compute_score(sk, net, hnum, daddr, dif, sdif, 
exact_dif);
                if (score > hiscore) {
-                       if (sk->sk_reuseport) {
-                               phash = inet6_ehashfn(net, daddr, hnum,
-                                                     saddr, sport);
-                               result = reuseport_select_sock(sk, phash,
-                                                              skb, doff);
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net, doff,
+                                                      inet6_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        hiscore = score;
                }
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index eecf9f0faf29..936c2a5c7147 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -169,7 +169,6 @@ static struct sock *udp6_lib_lookup2(struct net *net,
 {
        struct sock *sk, *result;
        int score, badness;
-       u32 hash = 0;
 
        result = NULL;
        badness = -1;
@@ -177,15 +176,12 @@ static struct sock *udp6_lib_lookup2(struct net *net,
                score = compute_score(sk, net, saddr, sport,
                                      daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
-                       if (sk->sk_reuseport) {
-                               hash = udp6_ehashfn(net, daddr, hnum,
-                                                   saddr, sport);
-
-                               result = reuseport_select_sock(sk, hash, skb,
-                                                       sizeof(struct udphdr));
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net,
+                                                      sizeof(struct udphdr),
+                                                      udp6_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        badness = score;
                }
@@ -206,7 +202,6 @@ struct sock *__udp6_lib_lookup(struct net *net,
        struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
        bool exact_dif = udp6_lib_exact_dif_match(net, skb);
        int score, badness;
-       u32 hash = 0;
 
        if (hslot->count > 10) {
                hash2 = ipv6_portaddr_hash(net, daddr, hnum);
@@ -244,14 +239,12 @@ struct sock *__udp6_lib_lookup(struct net *net,
                score = compute_score(sk, net, saddr, sport, daddr, hnum, dif,
                                      sdif, exact_dif);
                if (score > badness) {
-                       if (sk->sk_reuseport) {
-                               hash = udp6_ehashfn(net, daddr, hnum,
-                                                   saddr, sport);
-                               result = reuseport_select_sock(sk, hash, skb,
-                                                       sizeof(struct udphdr));
-                               if (result)
-                                       return result;
-                       }
+                       result = reuseport_select_sock(sk, skb, net,
+                                                      sizeof(struct udphdr),
+                                                      udp6_ehashfn, saddr,
+                                                      sport, daddr, hnum);
+                       if (result)
+                               return result;
                        result = sk;
                        badness = score;
                }
-- 
2.14.3

Reply via email to