Replace napi->thread with a new thread_node struct that has a back
pointer to napi_struct.

Make the NAPI thread use the thread_node as its data pointer so that
it can poll different NAPIs throughout its lifetime.

Park the thread and save the thread_node in napi_config on napi_del.
Restore the node and unpark the thread on napi_add_config.

Signed-off-by: Shuhao Tan <[email protected]>
---
 include/linux/netdevice.h |  13 +++-
 net/core/dev.c            | 151 +++++++++++++++++++++++++++++---------
 net/core/netdev-genl.c    |  12 ++-
 3 files changed, 139 insertions(+), 37 deletions(-)

diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 7f4f0837c09f..1cda88607e99 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -63,6 +63,7 @@ struct dsa_port;
 struct ip_tunnel_parm_kern;
 struct macsec_context;
 struct macsec_ops;
+struct napi_struct;
 struct netdev_config;
 struct netdev_name_node;
 struct sd_flow_limit;
@@ -363,10 +364,20 @@ struct gro_node {
        u32                     cached_napi_id;
 };
 
+/*
+ * Threaded NAPI kthread handle, kept in napi_config while its NAPI is deleted
+ */
+struct napi_thread_node {
+       struct task_struct *thread;     /* napi_threaded_poll() kthread */
+       struct napi_struct *napi;       /* NAPI to poll; NULL while parked */
+       struct rcu_head rcu;            /* for kvfree_rcu() */
+};
+
 /*
  * Structure for per-NAPI config
  */
 struct napi_config {
+       struct napi_thread_node *thread_node; /* parked kthread, if any */
        u64 gro_flush_timeout;
        u64 irq_suspend_timeout;
        u32 defer_hard_irqs;
@@ -403,7 +414,7 @@ struct napi_struct {
        struct gro_node         gro;
        struct hrtimer          timer;
        /* all fields past this point are write-protected by netdev_lock */
-       struct task_struct      *thread;
+       struct napi_thread_node __rcu *thread_node;
        unsigned long           gro_flush_timeout;
        unsigned long           irq_suspend_timeout;
        u32                     defer_hard_irqs;
diff --git a/net/core/dev.c b/net/core/dev.c
index 202e35acb15b..f5e3b9e526af 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -1645,25 +1645,74 @@ EXPORT_SYMBOL(netdev_notify_peers);
 
 static int napi_threaded_poll(void *data);
 
-static int napi_kthread_create(struct napi_struct *n)
+/*
+ * Allocate a new napi_thread_node for @n, start its kthread and publish the
+ * node in @n->thread_node. Returns 0 or a negative errno.
+ */
+static int napi_thread_node_create(struct napi_struct *n)
 {
+       struct napi_thread_node *thread_node;
+       struct task_struct *thread;
        int err = 0;
 
+       thread_node = kvzalloc_obj(*thread_node);
+       if (!thread_node)
+               return -ENOMEM;
+
        /* Create and wake up the kthread once to put it in
         * TASK_INTERRUPTIBLE mode to avoid the blocked task
         * warning and work with loadavg.
         */
-       n->thread = kthread_run(napi_threaded_poll, n, "napi/%s-%d",
-                               n->dev->name, n->napi_id);
-       if (IS_ERR(n->thread)) {
-               err = PTR_ERR(n->thread);
+       thread_node->napi = n;
+       thread = kthread_run(napi_threaded_poll, thread_node, "napi/%s-%d",
+                            n->dev->name, n->napi_id);
+       if (IS_ERR(thread)) {
+               err = PTR_ERR(thread);
                pr_err("kthread_run failed with err %d\n", err);
-               n->thread = NULL;
+               goto free_thread_node;
        }
 
+       thread_node->thread = thread;
+       rcu_assign_pointer(n->thread_node, thread_node);
+
+       return 0;
+
+free_thread_node:
+       kvfree(thread_node);
+
        return err;
 }
 
+/* Stop @thread_node's kthread and free the node after an RCU grace period */
+static void napi_thread_node_stop(struct napi_thread_node *thread_node)
+{
+       kthread_stop(thread_node->thread);
+       kvfree_rcu(thread_node, rcu);
+}
+
+/*
+ * Attach a kthread to @n: reuse the one parked in @n->config by
+ * __netif_napi_del_locked() if there is one, otherwise create a new one.
+ */
+static int napi_kthread_create(struct napi_struct *n)
+{
+       struct napi_thread_node *thread_node;
+
+       if (n->config && n->config->thread_node) {
+               thread_node = n->config->thread_node;
+               n->config->thread_node = NULL;
+               /* Set the back pointer before the node is published and
+                * before the parked kthread can run and read it.
+                */
+               WRITE_ONCE(thread_node->napi, n);
+               rcu_assign_pointer(n->thread_node, thread_node);
+               kthread_unpark(thread_node->thread);
+               return 0;
+       }
+
+       return napi_thread_node_create(n);
+}
+
 static int __dev_open(struct net_device *dev, struct netlink_ext_ack *extack)
 {
        const struct net_device_ops *ops = dev->netdev_ops;
@@ -4949,7 +4986,7 @@ EXPORT_SYMBOL(__dev_direct_xmit);
 /*************************************************************************
  *                     Receiver routines
  *************************************************************************/
-static DEFINE_PER_CPU(struct task_struct *, backlog_napi);
+static DEFINE_PER_CPU(struct napi_thread_node, backlog_napi);
 
 int weight_p __read_mostly = 64;           /* old backlog weight */
 int dev_weight_rx_bias __read_mostly = 1;  /* bias for backlog weight */
@@ -4959,10 +4996,11 @@ int dev_weight_tx_bias __read_mostly = 1;  /* bias for output_queue quota */
 static inline void ____napi_schedule(struct softnet_data *sd,
                                      struct napi_struct *napi)
 {
-       struct task_struct *thread;
+       struct napi_thread_node *thread_node;
 
        lockdep_assert_irqs_disabled();
 
+       rcu_read_lock();
        if (test_bit(NAPI_STATE_THREADED, &napi->state)) {
                /* Paired with smp_mb__before_atomic() in
                 * napi_enable()/netif_set_threaded().
@@ -4970,18 +5008,21 @@ static inline void ____napi_schedule(struct softnet_data *sd,
                 * read on napi->thread. Only call
                 * wake_up_process() when it's not NULL.
                 */
-               thread = READ_ONCE(napi->thread);
-               if (thread) {
-                       if (use_backlog_threads() && thread == raw_cpu_read(backlog_napi))
+               thread_node = rcu_dereference(napi->thread_node);
+               if (thread_node) {
+                       if (use_backlog_threads() &&
+                           thread_node == this_cpu_ptr(&backlog_napi))
                                goto use_local_napi;
 
                        set_bit(NAPI_STATE_SCHED_THREADED, &napi->state);
-                       wake_up_process(thread);
+                       wake_up_process(thread_node->thread);
+                       rcu_read_unlock();
                        return;
                }
        }
 
 use_local_napi:
+       rcu_read_unlock();
        DEBUG_NET_WARN_ON_ONCE(!list_empty(&napi->poll_list));
        list_add_tail(&napi->poll_list, &sd->poll_list);
        WRITE_ONCE(napi->list_owner, smp_processor_id());
@@ -7148,6 +7189,7 @@ static enum hrtimer_restart napi_watchdog(struct hrtimer *timer)
 
 static void napi_stop_kthread(struct napi_struct *napi)
 {
+       struct napi_thread_node *thread_node;
        unsigned long val, new;
 
        /* Wait until the napi STATE_THREADED is unset. */
@@ -7180,8 +7222,9 @@ static void napi_stop_kthread(struct napi_struct *napi)
                msleep(20);
        }
 
-       kthread_stop(napi->thread);
-       napi->thread = NULL;
+       thread_node = netdev_lock_dereference(napi->thread_node, napi->dev);
+       rcu_assign_pointer(napi->thread_node, NULL);
+       napi_thread_node_stop(thread_node);
 }
 
 static void napi_set_threaded_state(struct napi_struct *napi,
@@ -7197,9 +7240,14 @@ static void napi_set_threaded_state(struct napi_struct *napi,
 int napi_set_threaded(struct napi_struct *napi,
                       enum netdev_napi_threaded threaded)
 {
+       struct napi_thread_node *thread_node;
+
+       thread_node = netdev_lock_dereference(napi->thread_node, napi->dev);
+
        if (threaded) {
-               if (!napi->thread) {
+               if (!thread_node) {
+                       /* May reuse a kthread parked in napi->config */
                        int err = napi_kthread_create(napi);
 
                        if (err)
                                return err;
@@ -7215,7 +7262,7 @@ int napi_set_threaded(struct napi_struct *napi,
          * softirq mode will happen in the next round of napi_schedule().
          * This should not cause hiccups/stalls to the live traffic.
          */
-       if (!threaded && napi->thread) {
+       if (!threaded && thread_node) {
                napi_stop_kthread(napi);
        } else {
                /* Make sure kthread is created before THREADED bit is set. */
@@ -7236,8 +7283,9 @@ int netif_set_threaded(struct net_device *dev,
 
        if (threaded) {
                list_for_each_entry(napi, &dev->napi_list, dev_list) {
-                       if (!napi->thread) {
+                       /* protected by assertion above */
+                       if (!rcu_dereference_protected(napi->thread_node, 1)) {
                                err = napi_kthread_create(napi);
                                if (err) {
                                        threaded = NETDEV_NAPI_THREADED_DISABLED;
                                        break;
@@ -7253,8 +7301,14 @@ int netif_set_threaded(struct net_device *dev,
                WARN_ON_ONCE(napi_set_threaded(napi, threaded));
 
        /* Override the config for all NAPIs even if currently not listed */
-       for (i = 0; i < dev->num_napi_configs; i++)
+       for (i = 0; i < dev->num_napi_configs; i++) {
                dev->napi_config[i].threaded = threaded;
+               /* Stop parked threads in inactive napi_configs */
+               if (!threaded && dev->napi_config[i].thread_node) {
+                       napi_thread_node_stop(dev->napi_config[i].thread_node);
+                       dev->napi_config[i].thread_node = NULL;
+               }
+       }
 
        return err;
 }
@@ -7657,7 +7711,7 @@ void napi_enable_locked(struct napi_struct *n)
                BUG_ON(!test_bit(NAPI_STATE_SCHED, &val));
 
                new = val & ~(NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC);
-               if (n->dev->threaded && n->thread)
+               if (n->dev->threaded && rcu_access_pointer(n->thread_node))
                        new |= NAPIF_STATE_THREADED;
        } while (!try_cmpxchg(&n->state, &val, new));
 }
@@ -7682,6 +7736,8 @@ EXPORT_SYMBOL(napi_enable);
 /* Must be called in process context */
 void __netif_napi_del_locked(struct napi_struct *napi)
 {
+       struct napi_thread_node *thread_node;
+
        netdev_assert_locked(napi->dev);
 
        if (!test_and_clear_bit(NAPI_STATE_LISTED, &napi->state))
@@ -7693,6 +7749,19 @@ void __netif_napi_del_locked(struct napi_struct *napi)
        if (test_and_clear_bit(NAPI_STATE_HAS_NOTIFIER, &napi->state))
                irq_set_affinity_notifier(napi->irq, NULL);
 
+       thread_node = netdev_lock_dereference(napi->thread_node, napi->dev);
+       RCU_INIT_POINTER(napi->thread_node, NULL);
+       if (thread_node && napi->config) {
+               /* Don't leak an older node still parked in this config */
+               if (napi->config->thread_node)
+                       napi_thread_node_stop(napi->config->thread_node);
+               kthread_park(thread_node->thread);
+               WRITE_ONCE(thread_node->napi, NULL);
+               napi->config->thread_node = thread_node;
+       } else if (thread_node) {
+               napi_thread_node_stop(thread_node);
+       }
+
        if (napi->config) {
                napi->index = -1;
                napi->config = NULL;
@@ -7702,11 +7770,6 @@ void __netif_napi_del_locked(struct napi_struct *napi)
        napi_free_frags(napi);
 
        gro_cleanup(&napi->gro);
-
-       if (napi->thread) {
-               kthread_stop(napi->thread);
-               napi->thread = NULL;
-       }
 }
 EXPORT_SYMBOL(__netif_napi_del_locked);
 
@@ -7802,11 +7865,25 @@ static int napi_poll(struct napi_struct *n, struct list_head *repoll)
        return work;
 }
 
-static int napi_thread_wait(struct napi_struct *napi)
+static struct napi_struct *
+napi_thread_wait(struct napi_thread_node *thread_node)
 {
+       struct napi_struct *napi = READ_ONCE(thread_node->napi);
+
        set_current_state(TASK_INTERRUPTIBLE);
 
        while (!kthread_should_stop()) {
+               if (kthread_should_park()) {
+                       kthread_parkme();
+                       /* Unparked either with a new napi or to be stopped */
+                       napi = READ_ONCE(thread_node->napi);
+                       /* kthread_parkme() returns in TASK_RUNNING; go back
+                        * to the sleeping state before re-checking the loop.
+                        */
+                       set_current_state(TASK_INTERRUPTIBLE);
+                       continue;
+               }
+
                /* Testing SCHED_THREADED bit here to make sure the current
                 * kthread owns this napi and could poll on this napi.
                 * Testing SCHED bit is not enough because SCHED bit might be
@@ -7815,7 +7888,7 @@ static int napi_thread_wait(struct napi_struct *napi)
                if (test_bit(NAPI_STATE_SCHED_THREADED, &napi->state)) {
                        WARN_ON(!list_empty(&napi->poll_list));
                        __set_current_state(TASK_RUNNING);
-                       return 0;
+                       return napi;
                }
 
                schedule();
@@ -7823,7 +7896,7 @@ static int napi_thread_wait(struct napi_struct *napi)
        }
        __set_current_state(TASK_RUNNING);
 
-       return -1;
+       return NULL;
 }
 
 static void napi_threaded_poll_loop(struct napi_struct *napi,
@@ -7880,13 +7953,15 @@ static void napi_threaded_poll_loop(struct napi_struct *napi,
 
 static int napi_threaded_poll(void *data)
 {
-       struct napi_struct *napi = data;
+       struct napi_thread_node *thread_node = data;
        unsigned long last_qs = jiffies;
+       struct napi_struct *napi;
        bool want_busy_poll;
        bool in_busy_poll;
        unsigned long val;
 
-       while (!napi_thread_wait(napi)) {
+       /* napi_thread_wait() returns NULL once the kthread must stop */
+       while ((napi = napi_thread_wait(thread_node))) {
                val = READ_ONCE(napi->state);
 
                want_busy_poll = val & NAPIF_STATE_THREADED_BUSY_POLL;
@@ -12155,6 +12234,8 @@ EXPORT_SYMBOL(alloc_netdev_mqs);
 
 static void netdev_napi_exit(struct net_device *dev)
 {
+       unsigned int i;
+
        if (!list_empty(&dev->napi_list)) {
                struct napi_struct *p, *n;
 
@@ -12166,6 +12247,13 @@ static void netdev_napi_exit(struct net_device *dev)
                synchronize_net();
        }
 
+       /* Stop kthreads that are still parked in napi_config entries */
+       for (i = 0; i < dev->num_napi_configs; i++) {
+               struct napi_thread_node *tn = dev->napi_config[i].thread_node;
+
+               if (tn)
+                       napi_thread_node_stop(tn);
+       }
        kvfree(dev->napi_config);
 }
 
@@ -13204,12 +13289,12 @@ static void backlog_napi_setup(unsigned int cpu)
        struct softnet_data *sd = per_cpu_ptr(&softnet_data, cpu);
        struct napi_struct *napi = &sd->backlog;
 
-       napi->thread = this_cpu_read(backlog_napi);
+       rcu_assign_pointer(napi->thread_node, per_cpu_ptr(&backlog_napi, cpu));
        set_bit(NAPI_STATE_THREADED, &napi->state);
 }
 
 static struct smp_hotplug_thread backlog_threads = {
-       .store                  = &backlog_napi,
+       .store                  = &backlog_napi.thread,
        .thread_should_run      = backlog_napi_should_run,
        .thread_fn              = run_backlog_napi,
        .thread_comm            = "backlog_napi/%u",
diff --git a/net/core/netdev-genl.c b/net/core/netdev-genl.c
index 11b0b91683d7..f2ecdb26d6f1 100644
--- a/net/core/netdev-genl.c
+++ b/net/core/netdev-genl.c
@@ -162,6 +162,7 @@ static int
 netdev_nl_napi_fill_one(struct sk_buff *rsp, struct napi_struct *napi,
                         const struct genl_info *info)
 {
+       struct napi_thread_node *thread_node;
        unsigned long irq_suspend_timeout;
        unsigned long gro_flush_timeout;
        u32 napi_defer_hard_irqs;
@@ -188,11 +189,16 @@ netdev_nl_napi_fill_one(struct sk_buff *rsp, struct napi_struct *napi,
                          napi_get_threaded(napi)))
                goto nla_put_failure;
 
-       if (napi->thread) {
-               pid = task_pid_nr(napi->thread);
-               if (nla_put_u32(rsp, NETDEV_A_NAPI_PID, pid))
+       rcu_read_lock();
+       thread_node = rcu_dereference(napi->thread_node);
+       if (thread_node) {
+               pid = task_pid_nr(thread_node->thread);
+               if (nla_put_u32(rsp, NETDEV_A_NAPI_PID, pid)) {
+                       rcu_read_unlock();
                        goto nla_put_failure;
+               }
        }
+       rcu_read_unlock();
 
        napi_defer_hard_irqs = napi_get_defer_hard_irqs(napi);
        if (nla_put_s32(rsp, NETDEV_A_NAPI_DEFER_HARD_IRQS,
-- 
2.54.0.1136.gdb2ca164c4-goog


Reply via email to