commit:     2d4b45c54778aa120dd8467beb7e9a3c42005258
Author:     Mike Pagano <mpagano <AT> gentoo <DOT> org>
AuthorDate: Mon Jun  7 11:22:42 2021 +0000
Commit:     Mike Pagano <mpagano <AT> gentoo <DOT> org>
CommitDate: Mon Jun  7 11:22:42 2021 +0000
URL:        https://gitweb.gentoo.org/proj/linux-patches.git/commit/?id=2d4b45c5

Upgrade WireGuard patch to v1.0.20210606

Signed-off-by: Mike Pagano <mpagano <AT> gentoo.org>

 0000_README                                        |   2 +-
 ... => 2400_wireguard-backport-v1.0.20210606.patch | 497 +++++++++++----------
 2 files changed, 271 insertions(+), 228 deletions(-)

diff --git a/0000_README b/0000_README
index f6d1278..fbcce52 100644
--- a/0000_README
+++ b/0000_README
@@ -551,7 +551,7 @@ Patch:  2000_BT-Check-key-sizes-only-if-Secure-Simple-Pairing-enabled.patch
 From:   https://lore.kernel.org/linux-bluetooth/20190522070540.48895-1-mar...@holtmann.org/raw
 Desc:   Bluetooth: Check key sizes only when Secure Simple Pairing is enabled. See bug #686758
 
-Patch:  2400_wireguard-backport-v1.0.20210424.patch
+Patch:  2400_wireguard-backport-v1.0.20210606.patch
 From:   https://git.zx2c4.com/wireguard-linux/
 Desc:   Extremely simple yet fast and modern VPN that utilizes state-of-the-art cryptography
 

diff --git a/2400_wireguard-backport-v1.0.20210424.patch b/2400_wireguard-backport-v1.0.20210606.patch
similarity index 99%
rename from 2400_wireguard-backport-v1.0.20210424.patch
rename to 2400_wireguard-backport-v1.0.20210606.patch
index 34d7aa5..a5b7b80 100755
--- a/2400_wireguard-backport-v1.0.20210424.patch
+++ b/2400_wireguard-backport-v1.0.20210606.patch
@@ -3380,7 +3380,7 @@ exit 0
 -      u32 u[5];
 -      /* ... silently appended r^3 and r^4 when using AVX2 */
 +asmlinkage void poly1305_init_x86_64(void *ctx,
-+                                   const u8 key[POLY1305_KEY_SIZE]);
++                                   const u8 key[POLY1305_BLOCK_SIZE]);
 +asmlinkage void poly1305_blocks_x86_64(void *ctx, const u8 *inp,
 +                                     const size_t len, const u32 padbit);
 +asmlinkage void poly1305_emit_x86_64(void *ctx, u8 mac[POLY1305_DIGEST_SIZE],
@@ -3462,7 +3462,7 @@ exit 0
 +}
  
 -      return crypto_poly1305_init(desc);
-+static void poly1305_simd_init(void *ctx, const u8 key[POLY1305_KEY_SIZE])
++static void poly1305_simd_init(void *ctx, const u8 key[POLY1305_BLOCK_SIZE])
 +{
 +      poly1305_init_x86_64(ctx, key);
  }
@@ -3523,7 +3523,7 @@ exit 0
  
 -      BUILD_BUG_ON(offsetof(struct poly1305_simd_desc_ctx, base));
 -      sctx = container_of(dctx, struct poly1305_simd_desc_ctx, base);
-+void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 *key)
++void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 key[POLY1305_KEY_SIZE])
 +{
 +      poly1305_simd_init(&dctx->h, key);
 +      dctx->s[0] = get_unaligned_le32(&key[16]);
@@ -4104,7 +4104,7 @@ exit 0
        .digestsize     = POLY1305_DIGEST_SIZE,
 --- b/include/crypto/internal/poly1305.h
 +++ b/include/crypto/internal/poly1305.h
-@@ -0,0 +1,33 @@
+@@ -0,0 +1,34 @@
 +/* SPDX-License-Identifier: GPL-2.0 */
 +/*
 + * Common values for the Poly1305 algorithm
@@ -4125,7 +4125,8 @@ exit 0
 + * only the ε-almost-∆-universal hash function (not the full MAC) is computed.
 + */
 +
-+void poly1305_core_setkey(struct poly1305_core_key *key, const u8 *raw_key);
++void poly1305_core_setkey(struct poly1305_core_key *key,
++                        const u8 raw_key[POLY1305_BLOCK_SIZE]);
 +static inline void poly1305_core_init(struct poly1305_state *state)
 +{
 +      *state = (struct poly1305_state){};
@@ -4140,7 +4141,7 @@ exit 0
 +#endif
 --- b/include/crypto/poly1305.h
 +++ b/include/crypto/poly1305.h
-@@ -14,51 +14,84 @@
+@@ -14,51 +14,86 @@
  #define POLY1305_DIGEST_SIZE  16
  
 +/* The poly1305_key and poly1305_state types are mostly opaque and
@@ -4206,8 +4207,10 @@ exit 0
 - */
 -void poly1305_core_setkey(struct poly1305_key *key, const u8 *raw_key);
 -static inline void poly1305_core_init(struct poly1305_state *state)
-+void poly1305_init_arch(struct poly1305_desc_ctx *desc, const u8 *key);
-+void poly1305_init_generic(struct poly1305_desc_ctx *desc, const u8 *key);
++void poly1305_init_arch(struct poly1305_desc_ctx *desc,
++                      const u8 key[POLY1305_KEY_SIZE]);
++void poly1305_init_generic(struct poly1305_desc_ctx *desc,
++                         const u8 key[POLY1305_KEY_SIZE]);
 +
 +static inline void poly1305_init(struct poly1305_desc_ctx *desc, const u8 *key)
 +{
@@ -4258,7 +4261,7 @@ exit 0
  #endif
 --- b/lib/crypto/poly1305.c
 +++ b/lib/crypto/poly1305.c
-@@ -0,0 +1,77 @@
+@@ -0,0 +1,78 @@
 +// SPDX-License-Identifier: GPL-2.0-or-later
 +/*
 + * Poly1305 authenticator algorithm, RFC7539
@@ -4273,7 +4276,8 @@ exit 0
 +#include <linux/module.h>
 +#include <asm/unaligned.h>
 +
-+void poly1305_init_generic(struct poly1305_desc_ctx *desc, const u8 *key)
++void poly1305_init_generic(struct poly1305_desc_ctx *desc,
++                         const u8 key[POLY1305_KEY_SIZE])
 +{
 +      poly1305_core_setkey(&desc->core_r, key);
 +      desc->s[0] = get_unaligned_le32(key + 16);
@@ -6150,7 +6154,7 @@ exit 0
 +
 +static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_neon);
 +
-+void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 *key)
++void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 key[POLY1305_KEY_SIZE])
 +{
 +      poly1305_init_arm64(&dctx->h, key);
 +      dctx->s[0] = get_unaligned_le32(key + 16);
@@ -8788,7 +8792,7 @@ exit 0
 +
 +static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_neon);
 +
-+void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 *key)
++void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 key[POLY1305_KEY_SIZE])
 +{
 +      poly1305_init_arm(&dctx->h, key);
 +      dctx->s[0] = get_unaligned_le32(key + 16);
@@ -9052,7 +9056,7 @@ exit 0
 +asmlinkage void poly1305_blocks_mips(void *state, const u8 *src, u32 len, u32 hibit);
 +asmlinkage void poly1305_emit_mips(void *state, u8 *digest, const u32 *nonce);
 +
-+void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 *key)
++void poly1305_init_arch(struct poly1305_desc_ctx *dctx, const u8 key[POLY1305_KEY_SIZE])
 +{
 +      poly1305_init_mips(&dctx->h, key);
 +      dctx->s[0] = get_unaligned_le32(key + 16);
@@ -30620,9 +30624,9 @@ exit 0
        u32 nh_key[NH_KEY_WORDS];
  };
  
---- /dev/null
+--- b/lib/crypto/poly1305-donna32.c
 +++ b/lib/crypto/poly1305-donna32.c
-@@ -0,0 +1,204 @@
+@@ -0,0 +1,205 @@
 +// SPDX-License-Identifier: GPL-2.0 OR MIT
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -30635,7 +30639,8 @@ exit 0
 +#include <asm/unaligned.h>
 +#include <crypto/internal/poly1305.h>
 +
-+void poly1305_core_setkey(struct poly1305_core_key *key, const u8 raw_key[16])
++void poly1305_core_setkey(struct poly1305_core_key *key,
++                        const u8 raw_key[POLY1305_BLOCK_SIZE])
 +{
 +      /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
 +      key->key.r[0] = (get_unaligned_le32(&raw_key[0])) & 0x3ffffff;
@@ -30827,9 +30832,9 @@ exit 0
 +      put_unaligned_le32(h3, &mac[12]);
 +}
 +EXPORT_SYMBOL(poly1305_core_emit);
---- /dev/null
+--- b/lib/crypto/poly1305-donna64.c
 +++ b/lib/crypto/poly1305-donna64.c
-@@ -0,0 +1,185 @@
+@@ -0,0 +1,186 @@
 +// SPDX-License-Identifier: GPL-2.0 OR MIT
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -30844,7 +30849,8 @@ exit 0
 +
 +typedef __uint128_t u128;
 +
-+void poly1305_core_setkey(struct poly1305_core_key *key, const u8 raw_key[16])
++void poly1305_core_setkey(struct poly1305_core_key *key,
++                        const u8 raw_key[POLY1305_BLOCK_SIZE])
 +{
 +      u64 t0, t1;
 +
@@ -35909,7 +35915,7 @@ exit 0
 +MODULE_AUTHOR("Jason A. Donenfeld <ja...@zx2c4.com>");
 --- a/arch/x86/Makefile
 +++ b/arch/x86/Makefile
-@@ -197,9 +197,10 @@ avx2_instr :=$(call as-instr,vpbroadcastb %xmm0$(comma)%ymm1,-DCONFIG_AS_AVX2=1)
+@@ -198,9 +198,10 @@ avx2_instr :=$(call as-instr,vpbroadcastb %xmm0$(comma)%ymm1,-DCONFIG_AS_AVX2=1)
  avx512_instr :=$(call as-instr,vpmovm2b %k1$(comma)%zmm5,-DCONFIG_AS_AVX512=1)
 sha1_ni_instr :=$(call as-instr,sha1msg1 %xmm0$(comma)%xmm1,-DCONFIG_AS_SHA1_NI=1)
 sha256_ni_instr :=$(call as-instr,sha256msg1 %xmm0$(comma)%xmm1,-DCONFIG_AS_SHA256_NI=1)
@@ -36513,11 +36519,10 @@ exit 0
  obj-$(CONFIG_EQUALIZER) += eql.o
  obj-$(CONFIG_IFB) += ifb.o
  obj-$(CONFIG_MACSEC) += macsec.o
---- /dev/null
+--- b/drivers/net/wireguard/Makefile
 +++ b/drivers/net/wireguard/Makefile
-@@ -0,0 +1,18 @@
-+ccflags-y := -O3
-+ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
+@@ -0,0 +1,17 @@
++ccflags-y := -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
 +ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG
 +wireguard-y := main.o
 +wireguard-y += noise.o
@@ -36536,7 +36541,7 @@ exit 0
 +obj-$(CONFIG_WIREGUARD) := wireguard.o
 --- b/drivers/net/wireguard/allowedips.c
 +++ b/drivers/net/wireguard/allowedips.c
-@@ -0,0 +1,377 @@
+@@ -0,0 +1,386 @@
 +// SPDX-License-Identifier: GPL-2.0
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -36545,6 +36550,8 @@ exit 0
 +#include "allowedips.h"
 +#include "peer.h"
 +
++static struct kmem_cache *node_cache;
++
 +static void swap_endian(u8 *dst, const u8 *src, u8 bits)
 +{
 +      if (bits == 32) {
@@ -36567,8 +36574,11 @@ exit 0
 +      node->bitlen = bits;
 +      memcpy(node->bits, src, bits / 8U);
 +}
-+#define CHOOSE_NODE(parent, key) \
-+      parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
++
++static inline u8 choose(struct allowedips_node *node, const u8 *key)
++{
++      return (key[node->bit_at_a] >> node->bit_at_b) & 1;
++}
 +
 +static void push_rcu(struct allowedips_node **stack,
 +                   struct allowedips_node __rcu *p, unsigned int *len)
@@ -36579,6 +36589,11 @@ exit 0
 +      }
 +}
 +
++static void node_free_rcu(struct rcu_head *rcu)
++{
++      kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu));
++}
++
 +static void root_free_rcu(struct rcu_head *rcu)
 +{
 +      struct allowedips_node *node, *stack[128] = {
@@ -36588,7 +36603,7 @@ exit 0
 +      while (len > 0 && (node = stack[--len])) {
 +              push_rcu(stack, node->bit[0], &len);
 +              push_rcu(stack, node->bit[1], &len);
-+              kfree(node);
++              kmem_cache_free(node_cache, node);
 +      }
 +}
 +
@@ -36605,60 +36620,6 @@ exit 0
 +      }
 +}
 +
-+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
-+                              struct wg_peer *peer, struct mutex *lock)
-+{
-+#define REF(p) rcu_access_pointer(p)
-+#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-+#define PUSH(p) ({                                                             \
-+              WARN_ON(IS_ENABLED(DEBUG) && len >= 128);                      \
-+              stack[len++] = p;                                              \
-+      })
-+
-+      struct allowedips_node __rcu **stack[128], **nptr;
-+      struct allowedips_node *node, *prev;
-+      unsigned int len;
-+
-+      if (unlikely(!peer || !REF(*top)))
-+              return;
-+
-+      for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
-+              nptr = stack[len - 1];
-+              node = DEREF(nptr);
-+              if (!node) {
-+                      --len;
-+                      continue;
-+              }
-+              if (!prev || REF(prev->bit[0]) == node ||
-+                  REF(prev->bit[1]) == node) {
-+                      if (REF(node->bit[0]))
-+                              PUSH(&node->bit[0]);
-+                      else if (REF(node->bit[1]))
-+                              PUSH(&node->bit[1]);
-+              } else if (REF(node->bit[0]) == prev) {
-+                      if (REF(node->bit[1]))
-+                              PUSH(&node->bit[1]);
-+              } else {
-+                      if (rcu_dereference_protected(node->peer,
-+                              lockdep_is_held(lock)) == peer) {
-+                              RCU_INIT_POINTER(node->peer, NULL);
-+                              list_del_init(&node->peer_list);
-+                              if (!node->bit[0] || !node->bit[1]) {
-+                                      rcu_assign_pointer(*nptr, DEREF(
-+                                             &node->bit[!REF(node->bit[0])]));
-+                                      kfree_rcu(node, rcu);
-+                                      node = DEREF(nptr);
-+                              }
-+                      }
-+                      --len;
-+              }
-+      }
-+
-+#undef REF
-+#undef DEREF
-+#undef PUSH
-+}
-+
 +static unsigned int fls128(u64 a, u64 b)
 +{
 +      return a ? fls64(a) + 64U : fls64(b);
@@ -36698,7 +36659,7 @@ exit 0
 +                      found = node;
 +              if (node->cidr == bits)
 +                      break;
-+              node = rcu_dereference_bh(CHOOSE_NODE(node, key));
++              node = rcu_dereference_bh(node->bit[choose(node, key)]);
 +      }
 +      return found;
 +}
@@ -36730,8 +36691,7 @@ exit 0
 +                         u8 cidr, u8 bits, struct allowedips_node **rnode,
 +                         struct mutex *lock)
 +{
-+      struct allowedips_node *node = rcu_dereference_protected(trie,
-+                                              lockdep_is_held(lock));
++      struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
 +      struct allowedips_node *parent = NULL;
 +      bool exact = false;
 +
@@ -36741,13 +36701,24 @@ exit 0
 +                      exact = true;
 +                      break;
 +              }
-+              node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
-+                                               lockdep_is_held(lock));
++              node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
 +      }
 +      *rnode = parent;
 +      return exact;
 +}
 +
++static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node)
++{
++      node->parent_bit_packed = (unsigned long)parent | bit;
++      rcu_assign_pointer(*parent, node);
++}
++
++static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
++{
++      u8 bit = choose(parent, node->bits);
++      connect_node(&parent->bit[bit], bit, node);
++}
++
 +static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
 +             u8 cidr, struct wg_peer *peer, struct mutex *lock)
 +{
@@ -36757,13 +36728,13 @@ exit 0
 +              return -EINVAL;
 +
 +      if (!rcu_access_pointer(*trie)) {
-+              node = kzalloc(sizeof(*node), GFP_KERNEL);
++              node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
 +              if (unlikely(!node))
 +                      return -ENOMEM;
 +              RCU_INIT_POINTER(node->peer, peer);
 +              list_add_tail(&node->peer_list, &peer->allowedips_list);
 +              copy_and_assign_cidr(node, key, cidr, bits);
-+              rcu_assign_pointer(*trie, node);
++              connect_node(trie, 2, node);
 +              return 0;
 +      }
 +      if (node_placement(*trie, key, cidr, bits, &node, lock)) {
@@ -36772,7 +36743,7 @@ exit 0
 +              return 0;
 +      }
 +
-+      newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
++      newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
 +      if (unlikely(!newnode))
 +              return -ENOMEM;
 +      RCU_INIT_POINTER(newnode->peer, peer);
@@ -36782,10 +36753,10 @@ exit 0
 +      if (!node) {
 +              down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
 +      } else {
-+              down = rcu_dereference_protected(CHOOSE_NODE(node, key),
-+                                               lockdep_is_held(lock));
++              const u8 bit = choose(node, key);
++              down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
 +              if (!down) {
-+                      rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
++                      connect_node(&node->bit[bit], bit, newnode);
 +                      return 0;
 +              }
 +      }
@@ -36793,30 +36764,29 @@ exit 0
 +      parent = node;
 +
 +      if (newnode->cidr == cidr) {
-+              rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
++              choose_and_connect_node(newnode, down);
 +              if (!parent)
-+                      rcu_assign_pointer(*trie, newnode);
++                      connect_node(trie, 2, newnode);
 +              else
-+                      rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
-+                                         newnode);
-+      } else {
-+              node = kzalloc(sizeof(*node), GFP_KERNEL);
-+              if (unlikely(!node)) {
-+                      list_del(&newnode->peer_list);
-+                      kfree(newnode);
-+                      return -ENOMEM;
-+              }
-+              INIT_LIST_HEAD(&node->peer_list);
-+              copy_and_assign_cidr(node, newnode->bits, cidr, bits);
++                      choose_and_connect_node(parent, newnode);
++              return 0;
++      }
 +
-+              rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
-+              rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
-+              if (!parent)
-+                      rcu_assign_pointer(*trie, node);
-+              else
-+                      rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
-+                                         node);
++      node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
++      if (unlikely(!node)) {
++              list_del(&newnode->peer_list);
++              kmem_cache_free(node_cache, newnode);
++              return -ENOMEM;
 +      }
++      INIT_LIST_HEAD(&node->peer_list);
++      copy_and_assign_cidr(node, newnode->bits, cidr, bits);
++
++      choose_and_connect_node(node, down);
++      choose_and_connect_node(node, newnode);
++      if (!parent)
++              connect_node(trie, 2, node);
++      else
++              choose_and_connect_node(parent, node);
 +      return 0;
 +}
 +
@@ -36874,9 +36844,41 @@ exit 0
 +void wg_allowedips_remove_by_peer(struct allowedips *table,
 +                                struct wg_peer *peer, struct mutex *lock)
 +{
++      struct allowedips_node *node, *child, **parent_bit, *parent, *tmp;
++      bool free_parent;
++
++      if (list_empty(&peer->allowedips_list))
++              return;
 +      ++table->seq;
-+      walk_remove_by_peer(&table->root4, peer, lock);
-+      walk_remove_by_peer(&table->root6, peer, lock);
++      list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
++              list_del_init(&node->peer_list);
++              RCU_INIT_POINTER(node->peer, NULL);
++              if (node->bit[0] && node->bit[1])
++                      continue;
++              child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
++                                                lockdep_is_held(lock));
++              if (child)
++                      child->parent_bit_packed = node->parent_bit_packed;
++              parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
++              *parent_bit = child;
++              parent = (void *)parent_bit -
++                       offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
++              free_parent = !rcu_access_pointer(node->bit[0]) &&
++                            !rcu_access_pointer(node->bit[1]) &&
++                            (node->parent_bit_packed & 3) <= 1 &&
++                            !rcu_access_pointer(parent->peer);
++              if (free_parent)
++                      child = rcu_dereference_protected(
++                                      parent->bit[!(node->parent_bit_packed & 1)],
++                                      lockdep_is_held(lock));
++              call_rcu(&node->rcu, node_free_rcu);
++              if (!free_parent)
++                      continue;
++              if (child)
++                      child->parent_bit_packed = parent->parent_bit_packed;
++              *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
++              call_rcu(&parent->rcu, node_free_rcu);
++      }
 +}
 +
 +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
@@ -36913,8 +36915,20 @@ exit 0
 +      return NULL;
 +}
 +
++int __init wg_allowedips_slab_init(void)
++{
++      node_cache = KMEM_CACHE(allowedips_node, 0);
++      return node_cache ? 0 : -ENOMEM;
++}
++
++void wg_allowedips_slab_uninit(void)
++{
++      rcu_barrier();
++      kmem_cache_destroy(node_cache);
++}
++
 +#include "selftest/allowedips.c"
---- /dev/null
+--- b/drivers/net/wireguard/allowedips.h
 +++ b/drivers/net/wireguard/allowedips.h
 @@ -0,0 +1,59 @@
 +/* SPDX-License-Identifier: GPL-2.0 */
@@ -36934,14 +36948,11 @@ exit 0
 +struct allowedips_node {
 +      struct wg_peer __rcu *peer;
 +      struct allowedips_node __rcu *bit[2];
-+      /* While it may seem scandalous that we waste space for v4,
-+       * we're alloc'ing to the nearest power of 2 anyway, so this
-+       * doesn't actually make a difference.
-+       */
-+      u8 bits[16] __aligned(__alignof(u64));
 +      u8 cidr, bit_at_a, bit_at_b, bitlen;
++      u8 bits[16] __aligned(__alignof(u64));
 +
-+      /* Keep rarely used list at bottom to be beyond cache line. */
++      /* Keep rarely used members at bottom to be beyond cache line. */
++      unsigned long parent_bit_packed;
 +      union {
 +              struct list_head peer_list;
 +              struct rcu_head rcu;
@@ -36952,7 +36963,7 @@ exit 0
 +      struct allowedips_node __rcu *root4;
 +      struct allowedips_node __rcu *root6;
 +      u64 seq;
-+};
++} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */
 +
 +void wg_allowedips_init(struct allowedips *table);
 +void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
@@ -36975,6 +36986,9 @@ exit 0
 +bool wg_allowedips_selftest(void);
 +#endif
 +
++int wg_allowedips_slab_init(void);
++void wg_allowedips_slab_uninit(void);
++
 +#endif /* _WG_ALLOWEDIPS_H */
 --- /dev/null
 +++ b/drivers/net/wireguard/cookie.c
@@ -37807,7 +37821,7 @@ exit 0
 +#endif /* _WG_DEVICE_H */
 --- b/drivers/net/wireguard/main.c
 +++ b/drivers/net/wireguard/main.c
-@@ -0,0 +1,63 @@
+@@ -0,0 +1,78 @@
 +// SPDX-License-Identifier: GPL-2.0
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -37831,13 +37845,22 @@ exit 0
 +{
 +      int ret;
 +
++      ret = wg_allowedips_slab_init();
++      if (ret < 0)
++              goto err_allowedips;
++
 +#ifdef DEBUG
++      ret = -ENOTRECOVERABLE;
 +      if (!wg_allowedips_selftest() || !wg_packet_counter_selftest() ||
 +          !wg_ratelimiter_selftest())
-+              return -ENOTRECOVERABLE;
++              goto err_peer;
 +#endif
 +      wg_noise_init();
 +
++      ret = wg_peer_init();
++      if (ret < 0)
++              goto err_peer;
++
 +      ret = wg_device_init();
 +      if (ret < 0)
 +              goto err_device;
@@ -37854,6 +37877,10 @@ exit 0
 +err_netlink:
 +      wg_device_uninit();
 +err_device:
++      wg_peer_uninit();
++err_peer:
++      wg_allowedips_slab_uninit();
++err_allowedips:
 +      return ret;
 +}
 +
@@ -37861,6 +37888,8 @@ exit 0
 +{
 +      wg_genetlink_uninit();
 +      wg_device_uninit();
++      wg_peer_uninit();
++      wg_allowedips_slab_uninit();
 +}
 +
 +module_init(mod_init);
@@ -39637,7 +39666,7 @@ exit 0
 +#endif /* _WG_NOISE_H */
 --- b/drivers/net/wireguard/peer.c
 +++ b/drivers/net/wireguard/peer.c
-@@ -0,0 +1,227 @@
+@@ -0,0 +1,240 @@
 +// SPDX-License-Identifier: GPL-2.0
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -39655,6 +39684,7 @@ exit 0
 +#include <linux/rcupdate.h>
 +#include <linux/list.h>
 +
++static struct kmem_cache *peer_cache;
 +static atomic64_t peer_counter = ATOMIC64_INIT(0);
 +
 +struct wg_peer *wg_peer_create(struct wg_device *wg,
@@ -39669,10 +39699,10 @@ exit 0
 +      if (wg->num_peers >= MAX_PEERS_PER_DEVICE)
 +              return ERR_PTR(ret);
 +
-+      peer = kzalloc(sizeof(*peer), GFP_KERNEL);
++      peer = kmem_cache_zalloc(peer_cache, GFP_KERNEL);
 +      if (unlikely(!peer))
 +              return ERR_PTR(ret);
-+      if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
++      if (unlikely(dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)))
 +              goto err;
 +
 +      peer->device = wg;
@@ -39704,7 +39734,7 @@ exit 0
 +      return peer;
 +
 +err:
-+      kfree(peer);
++      kmem_cache_free(peer_cache, peer);
 +      return ERR_PTR(ret);
 +}
 +
@@ -39728,7 +39758,7 @@ exit 0
 +      /* Mark as dead, so that we don't allow jumping contexts after. */
 +      WRITE_ONCE(peer->is_dead, true);
 +
-+      /* The caller must now synchronize_rcu() for this to take effect. */
++      /* The caller must now synchronize_net() for this to take effect. */
 +}
 +
 +static void peer_remove_after_dead(struct wg_peer *peer)
@@ -39800,7 +39830,7 @@ exit 0
 +      lockdep_assert_held(&peer->device->device_update_lock);
 +
 +      peer_make_dead(peer);
-+      synchronize_rcu();
++      synchronize_net();
 +      peer_remove_after_dead(peer);
 +}
 +
@@ -39818,7 +39848,7 @@ exit 0
 +              peer_make_dead(peer);
 +              list_add_tail(&peer->peer_list, &dead_peers);
 +      }
-+      synchronize_rcu();
++      synchronize_net();
 +      list_for_each_entry_safe(peer, temp, &dead_peers, peer_list)
 +              peer_remove_after_dead(peer);
 +}
@@ -39833,7 +39863,8 @@ exit 0
 +      /* The final zeroing takes care of clearing any remaining handshake key
 +       * material and other potentially sensitive information.
 +       */
-+      kzfree(peer);
++      memzero_explicit(peer, sizeof(*peer));
++      kmem_cache_free(peer_cache, peer);
 +}
 +
 +static void kref_release(struct kref *refcount)
@@ -39865,9 +39896,20 @@ exit 0
 +              return;
 +      kref_put(&peer->refcount, kref_release);
 +}
++
++int __init wg_peer_init(void)
++{
++      peer_cache = KMEM_CACHE(wg_peer, 0);
++      return peer_cache ? 0 : -ENOMEM;
++}
++
++void wg_peer_uninit(void)
++{
++      kmem_cache_destroy(peer_cache);
++}
 --- b/drivers/net/wireguard/peer.h
 +++ b/drivers/net/wireguard/peer.h
-@@ -0,0 +1,83 @@
+@@ -0,0 +1,86 @@
 +/* SPDX-License-Identifier: GPL-2.0 */
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -39950,6 +39992,9 @@ exit 0
 +void wg_peer_remove(struct wg_peer *peer);
 +void wg_peer_remove_all(struct wg_device *wg);
 +
++int wg_peer_init(void);
++void wg_peer_uninit(void);
++
 +#endif /* _WG_PEER_H */
 --- b/drivers/net/wireguard/peerlookup.c
 +++ b/drivers/net/wireguard/peerlookup.c
@@ -41411,9 +41456,9 @@ exit 0
 +err:
 +      dev_kfree_skb(skb);
 +}
---- /dev/null
+--- b/drivers/net/wireguard/selftest/allowedips.c
 +++ b/drivers/net/wireguard/selftest/allowedips.c
-@@ -0,0 +1,683 @@
+@@ -0,0 +1,676 @@
 +// SPDX-License-Identifier: GPL-2.0
 +/*
 + * Copyright (C) 2015-2019 Jason A. Donenfeld <ja...@zx2c4.com>. All Rights Reserved.
@@ -41435,32 +41480,22 @@ exit 0
 +
 +#include <linux/siphash.h>
 +
-+static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
-+                                            u8 cidr)
-+{
-+      swap_endian(dst, src, bits);
-+      memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
-+      if (cidr)
-+              dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
-+}
-+
 +static __init void print_node(struct allowedips_node *node, u8 bits)
 +{
 +      char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
-+      char *fmt_declaration = KERN_DEBUG
-+              "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
++      char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
++      u8 ip1[16], ip2[16], cidr1, cidr2;
 +      char *style = "dotted";
-+      u8 ip1[16], ip2[16];
 +      u32 color = 0;
 +
++      if (node == NULL)
++              return;
 +      if (bits == 32) {
 +              fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
-+              fmt_declaration = KERN_DEBUG
-+                      "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
++              fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
 +      } else if (bits == 128) {
 +              fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
-+              fmt_declaration = KERN_DEBUG
-+                      "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
++              fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
 +      }
 +      if (node->peer) {
 +              hsiphash_key_t key = { { 0 } };
@@ -41471,24 +41506,20 @@ exit 0
 +                      hsiphash_1u32(0xabad1dea, &key) % 200;
 +              style = "bold";
 +      }
-+      swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
-+      printk(fmt_declaration, ip1, node->cidr, style, color);
++      wg_allowedips_read_node(node, ip1, &cidr1);
++      printk(fmt_declaration, ip1, cidr1, style, color);
 +      if (node->bit[0]) {
-+              swap_endian_and_apply_cidr(ip2,
-+                              rcu_dereference_raw(node->bit[0])->bits, bits,
-+                              node->cidr);
-+              printk(fmt_connection, ip1, node->cidr, ip2,
-+                     rcu_dereference_raw(node->bit[0])->cidr);
-+              print_node(rcu_dereference_raw(node->bit[0]), bits);
++              wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
++              printk(fmt_connection, ip1, cidr1, ip2, cidr2);
 +      }
 +      if (node->bit[1]) {
-+              swap_endian_and_apply_cidr(ip2,
-+                              rcu_dereference_raw(node->bit[1])->bits,
-+                              bits, node->cidr);
-+              printk(fmt_connection, ip1, node->cidr, ip2,
-+                     rcu_dereference_raw(node->bit[1])->cidr);
-+              print_node(rcu_dereference_raw(node->bit[1]), bits);
++              wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
++              printk(fmt_connection, ip1, cidr1, ip2, cidr2);
 +      }
++      if (node->bit[0])
++              print_node(rcu_dereference_raw(node->bit[0]), bits);
++      if (node->bit[1])
++              print_node(rcu_dereference_raw(node->bit[1]), bits);
 +}
 +
 +static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
@@ -41537,8 +41568,8 @@ exit 0
 +{
 +      union nf_inet_addr mask;
 +
-+      memset(&mask, 0x00, 128 / 8);
-+      memset(&mask, 0xff, cidr / 8);
++      memset(&mask, 0, sizeof(mask));
++      memset(&mask.all, 0xff, cidr / 8);
 +      if (cidr % 32)
 +              mask.all[cidr / 32] = (__force u32)htonl(
 +                      (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
@@ -41565,42 +41596,36 @@ exit 0
 +}
 +
 +static __init inline bool
-+horrible_match_v4(const struct horrible_allowedips_node *node,
-+                struct in_addr *ip)
++horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
 +{
 +      return (ip->s_addr & node->mask.ip) == node->ip.ip;
 +}
 +
 +static __init inline bool
-+horrible_match_v6(const struct horrible_allowedips_node *node,
-+                struct in6_addr *ip)
-+{
-+      return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
-+                     node->ip.ip6[0] &&
-+             (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
-+                     node->ip.ip6[1] &&
-+             (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
-+                     node->ip.ip6[2] &&
++horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
++{
++      return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
++             (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
++             (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
 +             (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
 +}
 +
 +static __init void
-+horrible_insert_ordered(struct horrible_allowedips *table,
-+                      struct horrible_allowedips_node *node)
++horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
 +{
 +      struct horrible_allowedips_node *other = NULL, *where = NULL;
 +      u8 my_cidr = horrible_mask_to_cidr(node->mask);
 +
 +      hlist_for_each_entry(other, &table->head, table) {
-+              if (!memcmp(&other->mask, &node->mask,
-+                          sizeof(union nf_inet_addr)) &&
-+                  !memcmp(&other->ip, &node->ip,
-+                          sizeof(union nf_inet_addr)) &&
-+                  other->ip_version == node->ip_version) {
++              if (other->ip_version == node->ip_version &&
++                  !memcmp(&other->mask, &node->mask, sizeof(union 
nf_inet_addr)) &&
++                  !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) 
{
 +                      other->value = node->value;
 +                      kfree(node);
 +                      return;
 +              }
++      }
++      hlist_for_each_entry(other, &table->head, table) {
 +              where = other;
 +              if (horrible_mask_to_cidr(other->mask) <= my_cidr)
 +                      break;
@@ -41617,8 +41642,7 @@ exit 0
 +horrible_allowedips_insert_v4(struct horrible_allowedips *table,
 +                            struct in_addr *ip, u8 cidr, void *value)
 +{
-+      struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
-+                                                      GFP_KERNEL);
++      struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
 +
 +      if (unlikely(!node))
 +              return -ENOMEM;
@@ -41635,8 +41659,7 @@ exit 0
 +horrible_allowedips_insert_v6(struct horrible_allowedips *table,
 +                            struct in6_addr *ip, u8 cidr, void *value)
 +{
-+      struct horrible_allowedips_node *node = kzalloc(sizeof(*node),
-+                                                      GFP_KERNEL);
++      struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
 +
 +      if (unlikely(!node))
 +              return -ENOMEM;
@@ -41650,39 +41673,43 @@ exit 0
 +}
 +
 +static __init void *
-+horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
-+                            struct in_addr *ip)
++horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
 +{
 +      struct horrible_allowedips_node *node;
-+      void *ret = NULL;
 +
 +      hlist_for_each_entry(node, &table->head, table) {
-+              if (node->ip_version != 4)
-+                      continue;
-+              if (horrible_match_v4(node, ip)) {
-+                      ret = node->value;
-+                      break;
-+              }
++              if (node->ip_version == 4 && horrible_match_v4(node, ip))
++                      return node->value;
 +      }
-+      return ret;
++      return NULL;
 +}
 +
 +static __init void *
-+horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
-+                            struct in6_addr *ip)
++horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
 +{
 +      struct horrible_allowedips_node *node;
-+      void *ret = NULL;
 +
 +      hlist_for_each_entry(node, &table->head, table) {
-+              if (node->ip_version != 6)
++              if (node->ip_version == 6 && horrible_match_v6(node, ip))
++                      return node->value;
++      }
++      return NULL;
++}
++
++
++static __init void
++horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
++{
++      struct horrible_allowedips_node *node;
++      struct hlist_node *h;
++
++      hlist_for_each_entry_safe(node, h, &table->head, table) {
++              if (node->value != value)
 +                      continue;
-+              if (horrible_match_v6(node, ip)) {
-+                      ret = node->value;
-+                      break;
-+              }
++              hlist_del(&node->table);
++              kfree(node);
 +      }
-+      return ret;
++
 +}
 +
 +static __init bool randomized_test(void)
@@ -41712,6 +41739,7 @@ exit 0
 +                      goto free;
 +              }
 +              kref_init(&peers[i]->refcount);
++              INIT_LIST_HEAD(&peers[i]->allowedips_list);
 +      }
 +
 +      mutex_lock(&mutex);
@@ -41749,7 +41777,7 @@ exit 0
 +                      if (wg_allowedips_insert_v4(&t,
 +                                                  (struct in_addr *)mutated,
 +                                                  cidr, peer, &mutex) < 0) {
-+                              pr_err("allowedips random malloc: FAIL\n");
++                              pr_err("allowedips random self-test malloc: 
FAIL\n");
 +                              goto free_locked;
 +                      }
 +                      if (horrible_allowedips_insert_v4(&h,
@@ -41812,23 +41840,33 @@ exit 0
 +              print_tree(t.root6, 128);
 +      }
 +
-+      for (i = 0; i < NUM_QUERIES; ++i) {
-+              prandom_bytes(ip, 4);
-+              if (lookup(t.root4, 32, ip) !=
-+                  horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
-+                      pr_err("allowedips random self-test: FAIL\n");
-+                      goto free;
++      for (j = 0;; ++j) {
++              for (i = 0; i < NUM_QUERIES; ++i) {
++                      prandom_bytes(ip, 4);
++                      if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
++                              horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
++                              pr_err("allowedips random v4 self-test: FAIL\n");
++                              goto free;
++                      }
++                      prandom_bytes(ip, 16);
++                      if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
++                              pr_err("allowedips random v6 self-test: FAIL\n");
++                              goto free;
++                      }
 +              }
++              if (j >= NUM_PEERS)
++                      break;
++              mutex_lock(&mutex);
++              wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
++              mutex_unlock(&mutex);
++              horrible_allowedips_remove_by_value(&h, peers[j]);
 +      }
 +
-+      for (i = 0; i < NUM_QUERIES; ++i) {
-+              prandom_bytes(ip, 16);
-+              if (lookup(t.root6, 128, ip) !=
-+                  horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
-+                      pr_err("allowedips random self-test: FAIL\n");
-+                      goto free;
-+              }
++      if (t.root4 || t.root6) {
++              pr_err("allowedips random self-test removal: FAIL\n");
++              goto free;
 +      }
++
 +      ret = true;
 +
 +free:
@@ -43291,7 +43329,7 @@ exit 0
 +      if (new4)
 +              wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
 +      mutex_unlock(&wg->socket_update_lock);
-+      synchronize_rcu();
++      synchronize_net();
 +      sock_free(old4);
 +      sock_free(old6);
 +}
@@ -43827,7 +43865,7 @@ exit 0
 +#endif /* _WG_UAPI_WIREGUARD_H */
 --- b/tools/testing/selftests/wireguard/netns.sh
 +++ b/tools/testing/selftests/wireguard/netns.sh
-@@ -0,0 +1,635 @@
+@@ -0,0 +1,636 @@
 +#!/bin/bash
 +# SPDX-License-Identifier: GPL-2.0
 +#
@@ -44193,6 +44231,7 @@ exit 0
 +ip1 -4 route add default dev wg0 table 51820
 +ip1 -4 rule add not fwmark 51820 table 51820
 +ip1 -4 rule add table main suppress_prefixlength 0
++n1 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/vethc/rp_filter'
 +# Flood the pings instead of sending just one, to trigger routing table reference counting bugs.
 +n1 ping -W 1 -c 100 -f 192.168.99.7
 +n1 ping -W 1 -c 100 -f abab::1111
@@ -45370,7 +45409,7 @@ exit 0
 +}
 --- b/tools/testing/selftests/wireguard/qemu/kernel.config
 +++ b/tools/testing/selftests/wireguard/qemu/kernel.config
-@@ -0,0 +1,90 @@
+@@ -0,0 +1,89 @@
 +CONFIG_LOCALVERSION=""
 +CONFIG_NET=y
 +CONFIG_NETDEVICES=y
@@ -45392,7 +45431,6 @@ exit 0
 +CONFIG_NETFILTER_XT_NAT=y
 +CONFIG_NETFILTER_XT_MATCH_LENGTH=y
 +CONFIG_NETFILTER_XT_MARK=y
-+CONFIG_NF_CONNTRACK_IPV4=y
 +CONFIG_NF_NAT_IPV4=y
 +CONFIG_IP_NF_IPTABLES=y
 +CONFIG_IP_NF_FILTER=y
@@ -45497,3 +45535,8 @@ exit 0
 +
 +const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tunnel_parse_protocol };
 +EXPORT_SYMBOL(ip_tunnel_header_ops);
+--- /dev/null
++++ b/arch/mips/crypto/.gitignore
+@@ -0,0 +1,2 @@
++# SPDX-License-Identifier: GPL-2.0-only
++poly1305-core.S
