The branch stable/13 has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=535984471641c9435854ed23a80e49a8bf0e53b0

commit 535984471641c9435854ed23a80e49a8bf0e53b0
Author:     Mark Johnston <[email protected]>
AuthorDate: 2022-11-02 17:08:07 +0000
Commit:     Mark Johnston <[email protected]>
CommitDate: 2024-01-22 18:45:03 +0000

    inpcb: Allow SO_REUSEPORT_LB to be used in jails
    
    Currently SO_REUSEPORT_LB silently does nothing when set by a jailed
    process.  It is trivial to support this option in VNET jails, but it's
    also useful in traditional jails.
    
    This patch enables LB groups in jails with the following semantics:
    - all PCBs in a group must belong to the same jail,
    - PCB lookup prefers jailed groups to non-jailed groups
    
    This is a straightforward extension of the semantics used for individual
    listening sockets.  One pre-existing quirk of the lbgroup implementation
    is that non-jailed lbgroups are searched before jailed listening
    sockets; that is preserved with this change.
    
    Discussed with: glebius
    MFC after:      1 month
    Sponsored by:   Modirum MDPay
    Sponsored by:   Klara, Inc.
    Differential Revision:  https://reviews.freebsd.org/D37029
    
    (cherry picked from commit d93ec8cb1324d04d7cae19fb7fa98ade2ff33c80)
---
 sys/netinet/in_pcb.c   | 126 ++++++++++++++++++++++++++++---------------------
 sys/netinet/in_pcb.h   |   3 +-
 sys/netinet6/in6_pcb.c |  99 +++++++++++++++++++++++---------------
 3 files changed, 136 insertions(+), 92 deletions(-)

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index 3f8db305d3c9..001fd735cb4c 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -254,8 +254,8 @@ SYSCTL_COUNTER_U64(_net_inet_ip_rl, OID_AUTO, chgrl, CTLFLAG_RD,
  */
 
 static struct inpcblbgroup *
-in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, u_char vflag,
-    uint16_t port, const union in_dependaddr *addr, int size,
+in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred,
+    u_char vflag, uint16_t port, const union in_dependaddr *addr, int size,
     uint8_t numa_domain)
 {
        struct inpcblbgroup *grp;
@@ -263,8 +263,9 @@ in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, u_char vflag,
 
        bytes = __offsetof(struct inpcblbgroup, il_inp[size]);
        grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT);
-       if (!grp)
+       if (grp == NULL)
                return (NULL);
+       grp->il_cred = crhold(cred);
        grp->il_vflag = vflag;
        grp->il_lport = port;
        grp->il_numa_domain = numa_domain;
@@ -280,6 +281,7 @@ in_pcblbgroup_free_deferred(epoch_context_t ctx)
        struct inpcblbgroup *grp;
 
        grp = __containerof(ctx, struct inpcblbgroup, il_epoch_ctx);
+       crfree(grp->il_cred);
        free(grp, M_PCB);
 }
 
@@ -298,7 +300,7 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
        struct inpcblbgroup *grp;
        int i;
 
-       grp = in_pcblbgroup_alloc(hdr, old_grp->il_vflag,
+       grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag,
            old_grp->il_lport, &old_grp->il_dependladdr, size,
            old_grp->il_numa_domain);
        if (grp == NULL)
@@ -357,12 +359,6 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
        INP_WLOCK_ASSERT(inp);
        INP_HASH_WLOCK_ASSERT(pcbinfo);
 
-       /*
-        * Don't allow jailed socket to join local group.
-        */
-       if (inp->inp_socket != NULL && jailed(inp->inp_socket->so_cred))
-               return (0);
-
 #ifdef INET6
        /*
         * Don't allow IPv4 mapped INET6 wild socket.
@@ -377,17 +373,19 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
        idx = INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask);
        hdr = &pcbinfo->ipi_lbgrouphashbase[idx];
        CK_LIST_FOREACH(grp, hdr, il_list) {
-               if (grp->il_vflag == inp->inp_vflag &&
+               if (grp->il_cred->cr_prison == inp->inp_cred->cr_prison &&
+                   grp->il_vflag == inp->inp_vflag &&
                    grp->il_lport == inp->inp_lport &&
                    grp->il_numa_domain == numa_domain &&
                    memcmp(&grp->il_dependladdr,
                    &inp->inp_inc.inc_ie.ie_dependladdr,
-                   sizeof(grp->il_dependladdr)) == 0)
+                   sizeof(grp->il_dependladdr)) == 0) {
                        break;
+               }
        }
        if (grp == NULL) {
                /* Create new load balance group. */
-               grp = in_pcblbgroup_alloc(hdr, inp->inp_vflag,
+               grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag,
                    inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr,
                    INPCBLBGROUP_SIZMIN, numa_domain);
                if (grp == NULL)
@@ -2084,15 +2082,20 @@ in_pcblookup_local(struct inpcbinfo *pcbinfo, struct in_addr laddr,
 }
 #undef INP_LOOKUP_MAPPED_PCB_COST
 
+static bool
+in_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain)
+{
+       return (domain == M_NODOM || domain == grp->il_numa_domain);
+}
+
 static struct inpcb *
 in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
     const struct in_addr *laddr, uint16_t lport, const struct in_addr *faddr,
-    uint16_t fport, int lookupflags, int numa_domain)
+    uint16_t fport, int lookupflags, int domain)
 {
-       struct inpcb *local_wild, *numa_wild;
        const struct inpcblbgrouphead *hdr;
        struct inpcblbgroup *grp;
-       uint32_t idx;
+       struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
 
        INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -2100,17 +2103,15 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
            INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
 
        /*
-        * Order of socket selection:
-        * 1. non-wild.
-        * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD).
-        *
-        * NOTE:
-        * - Load balanced group does not contain jailed sockets
-        * - Load balanced group does not contain IPv4 mapped INET6 wild sockets
+        * Search for an LB group match based on the following criteria:
+        * - prefer jailed groups to non-jailed groups
+        * - prefer exact source address matches to wildcard matches
+        * - prefer groups bound to the specified NUMA domain
         */
-       local_wild = NULL;
-       numa_wild = NULL;
+       jail_exact = jail_wild = local_exact = local_wild = NULL;
        CK_LIST_FOREACH(grp, hdr, il_list) {
+               bool injail;
+
 #ifdef INET6
                if (!(grp->il_vflag & INP_IPV4))
                        continue;
@@ -2118,27 +2119,47 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
                if (grp->il_lport != lport)
                        continue;
 
-               idx = INP_PCBLBGROUP_PKTHASH(faddr->s_addr, lport, fport) %
-                   grp->il_inpcnt;
+               injail = prison_flag(grp->il_cred, PR_IP4) != 0;
+               if (injail && prison_check_ip4_locked(grp->il_cred->cr_prison,
+                   laddr) != 0)
+                       continue;
+
                if (grp->il_laddr.s_addr == laddr->s_addr) {
-                       if (numa_domain == M_NODOM ||
-                           grp->il_numa_domain == numa_domain) {
-                               return (grp->il_inp[idx]);
-                       } else {
-                               numa_wild = grp->il_inp[idx];
+                       if (injail) {
+                               jail_exact = grp;
+                               if (in_pcblookup_lb_numa_match(grp, domain))
+                                       /* This is a perfect match. */
+                                       goto out;
+                       } else if (local_exact == NULL ||
+                           in_pcblookup_lb_numa_match(grp, domain)) {
+                               local_exact = grp;
+                       }
+               } else if (grp->il_laddr.s_addr == INADDR_ANY &&
+                   (lookupflags & INPLOOKUP_WILDCARD) != 0) {
+                       if (injail) {
+                               if (jail_wild == NULL ||
+                                   in_pcblookup_lb_numa_match(grp, domain))
+                                       jail_wild = grp;
+                       } else if (local_wild == NULL ||
+                           in_pcblookup_lb_numa_match(grp, domain)) {
+                               local_wild = grp;
                        }
-               }
-               if (grp->il_laddr.s_addr == INADDR_ANY &&
-                   (lookupflags & INPLOOKUP_WILDCARD) != 0 &&
-                   (local_wild == NULL || numa_domain == M_NODOM ||
-                       grp->il_numa_domain == numa_domain)) {
-                       local_wild = grp->il_inp[idx];
                }
        }
-       if (numa_wild != NULL)
-               return (numa_wild);
 
-       return (local_wild);
+       if (jail_exact != NULL)
+               grp = jail_exact;
+       else if (jail_wild != NULL)
+               grp = jail_wild;
+       else if (local_exact != NULL)
+               grp = local_exact;
+       else
+               grp = local_wild;
+       if (grp == NULL)
+               return (NULL);
+out:
+       return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr->s_addr, lport, fport) %
+           grp->il_inpcnt]);
 }
 
 #ifdef PCBGROUP
@@ -2424,16 +2445,6 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
        if (tmpinp != NULL)
                return (tmpinp);
 
-       /*
-        * Then look in lb group (for wildcard match).
-        */
-       if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
-               inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr,
-                   fport, lookupflags, numa_domain);
-               if (inp != NULL)
-                       return (inp);
-       }
-
        /*
         * Then look for a wildcard match, if requested.
         */
@@ -2445,6 +2456,15 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
                struct inpcb *jail_wild = NULL;
                int injail;
 
+               /*
+                * First see if an LB group matches the request before scanning
+                * all sockets on this port.
+                */
+               inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr,
+                   fport, lookupflags, numa_domain);
+               if (inp != NULL)
+                       return (inp);
+
                /*
                 * Order of socket selection - we always prefer jails.
                 *      1. jailed, non-wild.
@@ -2791,8 +2811,8 @@ in_pcbremlists(struct inpcb *inp)
 
                INP_HASH_WLOCK(pcbinfo);
 
-               /* XXX: Only do if SO_REUSEPORT_LB set? */
-               in_pcbremlbgrouphash(inp);
+               if ((inp->inp_flags2 & INP_REUSEPORT_LB) != 0)
+                       in_pcbremlbgrouphash(inp);
 
                CK_LIST_REMOVE(inp, inp_hash);
                CK_LIST_REMOVE(inp, inp_portlist);
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
index a1ec40d21194..331635fa94fb 100644
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -564,9 +564,10 @@ struct inpcbgroup {
 struct inpcblbgroup {
        CK_LIST_ENTRY(inpcblbgroup) il_list;
        struct epoch_context il_epoch_ctx;
+       struct ucred    *il_cred;
        uint16_t        il_lport;                       /* (c) */
        u_char          il_vflag;                       /* (c) */
-       u_int8_t                il_numa_domain;
+       uint8_t         il_numa_domain;
        uint32_t        il_pad2;
        union in_dependaddr il_dependladdr;             /* (c) */
 #define        il_laddr        il_dependladdr.id46_addr.ia46_addr4
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
index 6d7c1d4b65b5..ee32fbbf1688 100644
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -897,15 +897,20 @@ in6_rtchange(struct inpcb *inp, int errno __unused)
        return inp;
 }
 
+static bool
+in6_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain)
+{
+       return (domain == M_NODOM || domain == grp->il_numa_domain);
+}
+
 static struct inpcb *
 in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
     const struct in6_addr *laddr, uint16_t lport, const struct in6_addr *faddr,
-    uint16_t fport, int lookupflags, uint8_t numa_domain)
+    uint16_t fport, int lookupflags, uint8_t domain)
 {
-       struct inpcb *local_wild, *numa_wild;
        const struct inpcblbgrouphead *hdr;
        struct inpcblbgroup *grp;
-       uint32_t idx;
+       struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
 
        INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -913,17 +918,15 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
            INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
 
        /*
-        * Order of socket selection:
-        * 1. non-wild.
-        * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD).
-        *
-        * NOTE:
-        * - Load balanced group does not contain jailed sockets.
-        * - Load balanced does not contain IPv4 mapped INET6 wild sockets.
+        * Search for an LB group match based on the following criteria:
+        * - prefer jailed groups to non-jailed groups
+        * - prefer exact source address matches to wildcard matches
+        * - prefer groups bound to the specified NUMA domain 
         */
-       local_wild = NULL;
-       numa_wild = NULL;
+       jail_exact = jail_wild = local_exact = local_wild = NULL;
        CK_LIST_FOREACH(grp, hdr, il_list) {
+               bool injail;
+
 #ifdef INET
                if (!(grp->il_vflag & INP_IPV6))
                        continue;
@@ -931,26 +934,47 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
                if (grp->il_lport != lport)
                        continue;
 
-               idx = INP_PCBLBGROUP_PKTHASH(INP6_PCBHASHKEY(faddr), lport,
-                   fport) % grp->il_inpcnt;
+               injail = prison_flag(grp->il_cred, PR_IP6) != 0;
+               if (injail && prison_check_ip6_locked(grp->il_cred->cr_prison,
+                   laddr) != 0)
+                       continue;
+
                if (IN6_ARE_ADDR_EQUAL(&grp->il6_laddr, laddr)) {
-                       if (numa_domain == M_NODOM ||
-                           grp->il_numa_domain == numa_domain) {
-                               return (grp->il_inp[idx]);
+                       if (injail) {
+                               jail_exact = grp;
+                               if (in6_pcblookup_lb_numa_match(grp, domain))
+                                       /* This is a perfect match. */
+                                       goto out;
+                       } else if (local_exact == NULL ||
+                           in6_pcblookup_lb_numa_match(grp, domain)) {
+                               local_exact = grp;
+                       }
+               } else if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) &&
+                   (lookupflags & INPLOOKUP_WILDCARD) != 0) {
+                       if (injail) {
+                               if (jail_wild == NULL ||
+                                   in6_pcblookup_lb_numa_match(grp, domain))
+                                       jail_wild = grp;
+                       } else if (local_wild == NULL ||
+                           in6_pcblookup_lb_numa_match(grp, domain)) {
+                               local_wild = grp;
                        }
-                       else
-                               numa_wild = grp->il_inp[idx];
-               }
-               if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) &&
-                   (lookupflags & INPLOOKUP_WILDCARD) != 0 &&
-                   (local_wild == NULL || numa_domain == M_NODOM ||
-                       grp->il_numa_domain == numa_domain)) {
-                       local_wild = grp->il_inp[idx];
                }
        }
-       if (numa_wild != NULL)
-               return (numa_wild);
-       return (local_wild);
+
+       if (jail_exact != NULL)
+               grp = jail_exact;
+       else if (jail_wild != NULL)
+               grp = jail_wild;
+       else if (local_exact != NULL)
+               grp = local_exact;
+       else
+               grp = local_wild;
+       if (grp == NULL)
+               return (NULL);
+out:
+       return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(INP6_PCBHASHKEY(faddr), lport, fport) %
+           grp->il_inpcnt]);
 }
 
 #ifdef PCBGROUP
@@ -1199,16 +1223,6 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
        if (tmpinp != NULL)
                return (tmpinp);
 
-       /*
-        * Then look in lb group (for wildcard match).
-        */
-       if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
-               inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr,
-                   fport, lookupflags, numa_domain);
-               if (inp != NULL)
-                       return (inp);
-       }
-
        /*
         * Then look for a wildcard match, if requested.
         */
@@ -1217,6 +1231,15 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
                struct inpcb *jail_wild = NULL;
                int injail;
 
+               /*
+                * First see if an LB group matches the request before scanning
+                * all sockets on this port.
+                */
+               inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr,
+                   fport, lookupflags, numa_domain);
+               if (inp != NULL)
+                       return (inp);
+
                /*
                 * Order of socket selection - we always prefer jails.
                 *      1. jailed, non-wild.

Reply via email to