Run a BPF program before looking up the listening socket, or in case of UDP
before looking up any socket, connected or not. The program is allowed to
change the destination address & port we use as keys for the lookup,
provided its return code is BPF_REDIRECT.

This allows us to redirect traffic destined to a set of addresses and ports
without binding sockets to all the addresses and ports we want to receive
on.

Suggested-by: Marek Majkowski <ma...@cloudflare.com>
Signed-off-by: Jakub Sitnicki <ja...@cloudflare.com>
---
 include/net/inet_hashtables.h | 39 +++++++++++++++++++++++++++++++++++
 net/ipv4/inet_hashtables.c    | 11 ++++++----
 net/ipv4/udp.c                |  1 +
 3 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index babb14136705..7d8b58b2ded0 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -418,4 +418,43 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 
 int inet_hash_connect(struct inet_timewait_death_row *death_row,
                      struct sock *sk);
+
+#ifdef CONFIG_BPF_SYSCALL
+/* Run net->inet_lookup_prog, when one is set, on the AF_INET lookup keys.
+ * *daddr and *hnum are overwritten only if the program returns BPF_REDIRECT.
+ */
+static inline void inet_lookup_run_bpf(struct net *net, const __be32 saddr,
+                                      const __be16 sport, __be32 *daddr,
+                                      unsigned short *hnum)
+{
+       struct bpf_inet_lookup_kern lookup = {
+               .family = AF_INET,
+               .saddr  = saddr,
+               .sport  = sport,
+               .daddr  = *daddr,
+               .hnum   = *hnum,
+       };
+       struct bpf_prog *lookup_prog;
+       int verdict;
+
+       rcu_read_lock();
+       lookup_prog = rcu_dereference(net->inet_lookup_prog);
+       verdict = lookup_prog ? BPF_PROG_RUN(lookup_prog, &lookup) : BPF_OK;
+       rcu_read_unlock();
+
+       if (verdict != BPF_REDIRECT)
+               return;
+       *daddr = lookup.daddr;
+       *hnum = lookup.hnum;
+}
+#else
+/* Stub for !CONFIG_BPF_SYSCALL: there is no program to run, so *daddr and
+ * *hnum are left as passed in. Takes struct net * like the real variant. */
+static inline void inet_lookup_run_bpf(struct net *net, const __be32 saddr,
+                                      const __be16 sport, __be32 *daddr,
+                                      unsigned short *hnum)
+{
+}
+#endif /* CONFIG_BPF_SYSCALL */
+
 #endif /* _INET_HASHTABLES_H */
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 942265d65eb3..ae3f1da1b4f6 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -300,24 +300,27 @@ struct sock *__inet_lookup_listener(struct net *net,
                                    const int dif, const int sdif)
 {
        struct inet_listen_hashbucket *ilb2;
+       unsigned short hnum2 = hnum;
        struct sock *result = NULL;
+       __be32 daddr2 = daddr;
        unsigned int hash2;
 
-       hash2 = ipv4_portaddr_hash(net, daddr, hnum);
+       inet_lookup_run_bpf(net, saddr, sport, &daddr2, &hnum2);
+       hash2 = ipv4_portaddr_hash(net, daddr2, hnum2);
        ilb2 = inet_lhash2_bucket(hashinfo, hash2);
 
        result = inet_lhash2_lookup(net, ilb2, skb, doff,
-                                   saddr, sport, daddr, hnum,
+                                   saddr, sport, daddr2, hnum2,
                                    dif, sdif);
        if (result)
                goto done;
 
        /* Lookup lhash2 with INADDR_ANY */
-       hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
+       hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum2);
        ilb2 = inet_lhash2_bucket(hashinfo, hash2);
 
        result = inet_lhash2_lookup(net, ilb2, skb, doff,
-                                   saddr, sport, htonl(INADDR_ANY), hnum,
+                                   saddr, sport, htonl(INADDR_ANY), hnum2,
                                    dif, sdif);
 done:
        if (unlikely(IS_ERR(result)))
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 8fb250ed53d4..c4f4c94525ec 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -467,6 +467,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
        struct udp_hslot *hslot2;
        bool exact_dif = udp_lib_exact_dif_match(net, skb);
 
+       inet_lookup_run_bpf(net, saddr, sport, &daddr, &hnum);
        hash2 = ipv4_portaddr_hash(net, daddr, hnum);
        slot2 = hash2 & udptable->mask;
        hslot2 = &udptable->hash2[slot2];
-- 
2.20.1

Reply via email to