On Thu, Apr 27, 2017 at 2:35 PM, Cong Wang <[email protected]> wrote:
> On Thu, Apr 27, 2017 at 1:31 PM, Cong Wang <[email protected]> wrote:
>> On Wed, Apr 26, 2017 at 2:20 PM, Paul Moore <[email protected]> wrote:
>>> Thanks for the report, this is the only one like it that I've seen.
>>> I'm looking at the code in Linus' tree and I'm not seeing anything
>>> obvious ... looking at the trace above it appears that the problem is
>>> when get_net() goes to bump the refcount and the passed net pointer is
>>> NULL; unless I'm missing something, the only way this would happen in
>>> kauditd_thread() is if the auditd_conn.pid value is non-zero but the
>>> auditd_conn.net pointer is NULL.
>>>
>>> That shouldn't happen.
>>>
>>
>> Looking at the code that reads/writes the global auditd_conn,
>> I don't see how it even works with RCU+spinlock, RCU plays
>> with pointers and you have to make a copy as its name implies.
>> But it looks like you simply use RCU+spinlock as a traditional
>> rwlock, it doesn't work.
>
> The attached patch seems working for me, I tried to boot my
> VM for 4 times, so far no crash or warning.
>

Or even better, save a memory allocation for reset path...
diff --git a/kernel/audit.c b/kernel/audit.c
index a871bf8..9953dbe 100644
--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -110,7 +110,7 @@ struct audit_net {
  * @pid: auditd PID
  * @portid: netlink portid
  * @net: the associated network namespace
- * @lock: spinlock to protect write access
+ * @rcu: rcu head
  *
  * Description:
  * This struct is RCU protected; you must either hold the RCU lock for reading
@@ -120,8 +120,9 @@ static struct auditd_connection {
        int pid;
        u32 portid;
        struct net *net;
-       spinlock_t lock;
-} auditd_conn;
+       struct rcu_head rcu;
+} *auditd_conn, null_conn;
+static DEFINE_SPINLOCK(auditd_conn_lock);
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -223,9 +224,11 @@ struct audit_reply {
 int auditd_test_task(const struct task_struct *task)
 {
        int rc;
+       pid_t pid;
 
        rcu_read_lock();
-       rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+       pid = rcu_dereference(auditd_conn)->pid;
+       rc = (pid && task->tgid == pid ? 1 : 0);
        rcu_read_unlock();
 
        return rc;
@@ -426,30 +429,39 @@ static int audit_set_failure(u32 state)
        return audit_do_config_change("audit_failure", &audit_failure, state);
 }
 
+static void auditd_conn_free(struct rcu_head *rcu)
+{
+       struct auditd_connection *cn = container_of(rcu, struct 
auditd_connection, rcu);
+
+       if (cn == &null_conn)
+               return;
+       if (cn->net)
+               put_net(cn->net);
+       kfree(cn);
+}
+
 /**
- * auditd_set - Set/Reset the auditd connection state
- * @pid: auditd PID
- * @portid: auditd netlink portid
- * @net: auditd network namespace pointer
+ * auditd_set - Set the auditd connection state
+ * @new_coon: the new auditd connection
  *
  * Description:
  * This function will obtain and drop network namespace references as
  * necessary.
  */
-static void auditd_set(int pid, u32 portid, struct net *net)
+static void auditd_set(struct auditd_connection *new_conn)
 {
+       struct auditd_connection *old_conn;
        unsigned long flags;
 
-       spin_lock_irqsave(&auditd_conn.lock, flags);
-       auditd_conn.pid = pid;
-       auditd_conn.portid = portid;
-       if (auditd_conn.net)
-               put_net(auditd_conn.net);
-       if (net)
-               auditd_conn.net = get_net(net);
-       else
-               auditd_conn.net = NULL;
-       spin_unlock_irqrestore(&auditd_conn.lock, flags);
+       if (new_conn->net)
+               get_net(new_conn->net);
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       old_conn = rcu_dereference_protected(auditd_conn,
+                                       lockdep_is_held(&auditd_conn_lock));
+       rcu_assign_pointer(auditd_conn, new_conn);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+       call_rcu(&old_conn->rcu, auditd_conn_free);
 }
 
 /**
@@ -548,8 +560,8 @@ static void auditd_reset(void)
 
        /* if it isn't already broken, break the connection */
        rcu_read_lock();
-       if (auditd_conn.pid)
-               auditd_set(0, 0, NULL);
+       if (rcu_dereference(auditd_conn)->pid)
+               auditd_set(&null_conn);
        rcu_read_unlock();
 
        /* flush all of the main and retry queues to the hold queue */
@@ -585,15 +597,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
         *       section netlink_unicast() should safely return an error */
 
        rcu_read_lock();
-       if (!auditd_conn.pid) {
+       if (!rcu_dereference(auditd_conn)->pid) {
                rcu_read_unlock();
                rc = -ECONNREFUSED;
                goto err;
        }
-       net = auditd_conn.net;
+       net = rcu_dereference(auditd_conn)->net;
        get_net(net);
        sk = audit_get_sk(net);
-       portid = auditd_conn.portid;
+       portid = rcu_dereference(auditd_conn)->portid;
        rcu_read_unlock();
 
        rc = netlink_unicast(sk, skb, portid, 0);
@@ -735,14 +747,14 @@ static int kauditd_thread(void *dummy)
        while (!kthread_should_stop()) {
                /* NOTE: see the lock comments in auditd_send_unicast_skb() */
                rcu_read_lock();
-               if (!auditd_conn.pid) {
+               if (!rcu_dereference(auditd_conn)->pid) {
                        rcu_read_unlock();
                        goto main_queue;
                }
-               net = auditd_conn.net;
+               net = rcu_dereference(auditd_conn)->net;
                get_net(net);
                sk = audit_get_sk(net);
-               portid = auditd_conn.portid;
+               portid = rcu_dereference(auditd_conn)->portid;
                rcu_read_unlock();
 
                /* attempt to flush the hold queue */
@@ -1103,7 +1115,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct 
nlmsghdr *nlh)
                s.enabled               = audit_enabled;
                s.failure               = audit_failure;
                rcu_read_lock();
-               s.pid                   = auditd_conn.pid;
+               s.pid                   = rcu_dereference(auditd_conn)->pid;
                rcu_read_unlock();
                s.rate_limit            = audit_rate_limit;
                s.backlog_limit         = audit_backlog_limit;
@@ -1139,17 +1151,22 @@ static int audit_receive_msg(struct sk_buff *skb, 
struct nlmsghdr *nlh)
                        int new_pid = s.pid;
                        pid_t auditd_pid;
                        pid_t requesting_pid = task_tgid_vnr(current);
+                       struct auditd_connection *new;
 
+                       new = kmalloc(sizeof(*new), GFP_KERNEL);
+                       if (!new)
+                               return -ENOMEM;
                        /* test the auditd connection */
                        audit_replace(requesting_pid);
 
                        rcu_read_lock();
-                       auditd_pid = auditd_conn.pid;
+                       auditd_pid = rcu_dereference(auditd_conn)->pid;
                        /* only the current auditd can unregister itself */
                        if ((!new_pid) && (requesting_pid != auditd_pid)) {
                                rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
+                               kfree(new);
                                return -EACCES;
                        }
                        /* replacing a healthy auditd is not allowed */
@@ -1157,6 +1174,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct 
nlmsghdr *nlh)
                                rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
+                               kfree(new);
                                return -EEXIST;
                        }
                        rcu_read_unlock();
@@ -1166,10 +1184,11 @@ static int audit_receive_msg(struct sk_buff *skb, 
struct nlmsghdr *nlh)
                                                        auditd_pid, 1);
 
                        if (new_pid) {
+                               new->pid = new_pid;
+                               new->portid = NETLINK_CB(skb).portid;
+                               new->net = sock_net(NETLINK_CB(skb).sk);
                                /* register a new auditd connection */
-                               auditd_set(new_pid,
-                                          NETLINK_CB(skb).portid,
-                                          sock_net(NETLINK_CB(skb).sk));
+                               auditd_set(new);
                                /* try to process any backlog */
                                wake_up_interruptible(&kauditd_wait);
                        } else
@@ -1448,7 +1467,7 @@ static void __net_exit audit_net_exit(struct net *net)
        struct audit_net *aunet = net_generic(net, audit_net_id);
 
        rcu_read_lock();
-       if (net == auditd_conn.net)
+       if (net == rcu_dereference(auditd_conn)->net)
                auditd_reset();
        rcu_read_unlock();
 
@@ -1470,8 +1489,9 @@ static int __init audit_init(void)
        if (audit_initialized == AUDIT_DISABLED)
                return 0;
 
-       memset(&auditd_conn, 0, sizeof(auditd_conn));
-       spin_lock_init(&auditd_conn.lock);
+       auditd_conn = kzalloc(sizeof(*auditd_conn), GFP_KERNEL);
+       if (!auditd_conn)
+               return -ENOMEM;
 
        skb_queue_head_init(&audit_queue);
        skb_queue_head_init(&audit_retry_queue);

Reply via email to