An rds_connection can get added during netns deletion between lines 528
and 529 of

  506 static void rds_tcp_kill_sock(struct net *net)
  :
  /* code to pull out all the rds_connections that should be destroyed */
  :
  528         spin_unlock_irq(&rds_tcp_conn_lock);
  529         list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
  530                 rds_conn_destroy(tc->t_cpath->cp_conn);

Such an rds_connection would miss the rds_conn_destroy()
loop (which cancels all pending work) and (if it was scheduled
after netns deletion) could trigger a use-after-free.

A similar race window exists for the module unload path
in rds_tcp_exit -> rds_tcp_destroy_conns.

To avoid the addition of new rds_connections during kill_sock
or netns_delete, this patch introduces a per-netns flag,
RTN_DELETE_PENDING, that will cause RDS connection creation to fail.
RCU is used to make sure that we wait for the critical
section of __rds_conn_create threads (that may have started before
the setting of RTN_DELETE_PENDING) to complete before starting
the connection destruction.

Reported-by: syzbot+bbd8e9a06452cc480...@syzkaller.appspotmail.com
Signed-off-by: Sowmini Varadhan <sowmini.varad...@oracle.com>
---
 net/rds/connection.c |    3 ++
 net/rds/tcp.c        |   82 ++++++++++++++++++++++++++++++++-----------------
 net/rds/tcp.h        |    1 +
 3 files changed, 57 insertions(+), 29 deletions(-)

diff --git a/net/rds/connection.c b/net/rds/connection.c
index b10c0ef..2ae539d 100644
--- a/net/rds/connection.c
+++ b/net/rds/connection.c
@@ -220,8 +220,10 @@ static void __rds_conn_path_init(struct rds_connection *conn,
                                     is_outgoing);
                conn->c_path[i].cp_index = i;
        }
+       rcu_read_lock();
        ret = trans->conn_alloc(conn, gfp);
        if (ret) {
+               rcu_read_unlock();
                kfree(conn->c_path);
                kmem_cache_free(rds_conn_slab, conn);
                conn = ERR_PTR(ret);
@@ -283,6 +285,7 @@ static void __rds_conn_path_init(struct rds_connection *conn,
                }
        }
        spin_unlock_irqrestore(&rds_conn_lock, flags);
+       rcu_read_unlock();
 
 out:
        return conn;
diff --git a/net/rds/tcp.c b/net/rds/tcp.c
index 9920d2f..2bdd3cc 100644
--- a/net/rds/tcp.c
+++ b/net/rds/tcp.c
@@ -274,14 +274,13 @@ static int rds_tcp_laddr_check(struct net *net, __be32 addr)
 static void rds_tcp_conn_free(void *arg)
 {
        struct rds_tcp_connection *tc = arg;
-       unsigned long flags;
 
        rdsdebug("freeing tc %p\n", tc);
 
-       spin_lock_irqsave(&rds_tcp_conn_lock, flags);
+       spin_lock_bh(&rds_tcp_conn_lock);
        if (!tc->t_tcp_node_detached)
                list_del(&tc->t_tcp_node);
-       spin_unlock_irqrestore(&rds_tcp_conn_lock, flags);
+       spin_unlock_bh(&rds_tcp_conn_lock);
 
        kmem_cache_free(rds_tcp_conn_slab, tc);
 }
@@ -296,7 +295,7 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp)
                tc = kmem_cache_alloc(rds_tcp_conn_slab, gfp);
                if (!tc) {
                        ret = -ENOMEM;
-                       break;
+                       goto fail;
                }
                mutex_init(&tc->t_conn_path_lock);
                tc->t_sock = NULL;
@@ -306,14 +305,25 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp)
 
                conn->c_path[i].cp_transport_data = tc;
                tc->t_cpath = &conn->c_path[i];
+               tc->t_tcp_node_detached = true;
 
-               spin_lock_irq(&rds_tcp_conn_lock);
-               tc->t_tcp_node_detached = false;
-               list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
-               spin_unlock_irq(&rds_tcp_conn_lock);
                rdsdebug("rds_conn_path [%d] tc %p\n", i,
                         conn->c_path[i].cp_transport_data);
        }
+       spin_lock_bh(&rds_tcp_conn_lock);
+       if (rds_tcp_netns_delete_pending(rds_conn_net(conn))) {
+               rdsdebug("RTN_DELETE_PENDING\n");
+               ret = -ENETDOWN;
+               spin_unlock_bh(&rds_tcp_conn_lock);
+               goto fail;
+       }
+       for (i = 0; i < RDS_MPATH_WORKERS; i++) {
+               tc = conn->c_path[i].cp_transport_data;
+               tc->t_tcp_node_detached = false;
+               list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
+       }
+       spin_unlock_bh(&rds_tcp_conn_lock);
+fail:
        if (ret) {
                for (j = 0; j < i; j++)
                        rds_tcp_conn_free(conn->c_path[j].cp_transport_data);
@@ -332,23 +342,6 @@ static bool list_has_conn(struct list_head *list, struct rds_connection *conn)
        return false;
 }
 
-static void rds_tcp_destroy_conns(void)
-{
-       struct rds_tcp_connection *tc, *_tc;
-       LIST_HEAD(tmp_list);
-
-       /* avoid calling conn_destroy with irqs off */
-       spin_lock_irq(&rds_tcp_conn_lock);
-       list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
-               if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn))
-                       list_move_tail(&tc->t_tcp_node, &tmp_list);
-       }
-       spin_unlock_irq(&rds_tcp_conn_lock);
-
-       list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
-               rds_conn_destroy(tc->t_cpath->cp_conn);
-}
-
 static void rds_tcp_exit(void);
 
 struct rds_transport rds_tcp_transport = {
@@ -382,8 +375,30 @@ struct rds_tcp_net {
        struct ctl_table *ctl_table;
        int sndbuf_size;
        int rcvbuf_size;
+       unsigned long  rtn_flags;
+#define        RTN_DELETE_PENDING      0
 };
 
+static void rds_tcp_destroy_conns(void)
+{
+       struct rds_tcp_connection *tc, *_tc;
+       struct rds_tcp_net *rtn = net_generic(&init_net, rds_tcp_netid);
+       LIST_HEAD(tmp_list);
+
+       /* avoid calling conn_destroy with irqs off */
+       set_bit(RTN_DELETE_PENDING, &rtn->rtn_flags);
+       synchronize_rcu();
+       spin_lock_bh(&rds_tcp_conn_lock);
+       list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
+               if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn))
+                       list_move_tail(&tc->t_tcp_node, &tmp_list);
+       }
+       spin_unlock_bh(&rds_tcp_conn_lock);
+
+       list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
+               rds_conn_destroy(tc->t_cpath->cp_conn);
+}
+
 /* All module specific customizations to the RDS-TCP socket should be done in
  * rds_tcp_tune() and applied after socket creation.
  */
@@ -504,6 +519,13 @@ static void __net_exit rds_tcp_exit_net(struct net *net)
        .size = sizeof(struct rds_tcp_net),
 };
 
+bool rds_tcp_netns_delete_pending(struct net *net)
+{
+       struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid);
+
+       return test_bit(RTN_DELETE_PENDING, &rtn->rtn_flags);
+}
+
 static void rds_tcp_kill_sock(struct net *net)
 {
        struct rds_tcp_connection *tc, *_tc;
@@ -513,7 +535,9 @@ static void rds_tcp_kill_sock(struct net *net)
 
        rtn->rds_tcp_listen_sock = NULL;
        rds_tcp_listen_stop(lsock, &rtn->rds_tcp_accept_w);
-       spin_lock_irq(&rds_tcp_conn_lock);
+       set_bit(RTN_DELETE_PENDING, &rtn->rtn_flags);
+       synchronize_rcu();
+       spin_lock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
                struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -526,7 +550,7 @@ static void rds_tcp_kill_sock(struct net *net)
                        tc->t_tcp_node_detached = true;
                }
        }
-       spin_unlock_irq(&rds_tcp_conn_lock);
+       spin_unlock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
                rds_conn_destroy(tc->t_cpath->cp_conn);
 }
@@ -574,7 +598,7 @@ static void rds_tcp_sysctl_reset(struct net *net)
 {
        struct rds_tcp_connection *tc, *_tc;
 
-       spin_lock_irq(&rds_tcp_conn_lock);
+       spin_lock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
                struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -584,7 +608,7 @@ static void rds_tcp_sysctl_reset(struct net *net)
                /* reconnect with new parameters */
                rds_conn_path_drop(tc->t_cpath, false);
        }
-       spin_unlock_irq(&rds_tcp_conn_lock);
+       spin_unlock_bh(&rds_tcp_conn_lock);
 }
 
 static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write,
diff --git a/net/rds/tcp.h b/net/rds/tcp.h
index c6fa080..b07dbd7 100644
--- a/net/rds/tcp.h
+++ b/net/rds/tcp.h
@@ -60,6 +60,7 @@ void rds_tcp_restore_callbacks(struct socket *sock,
 u64 rds_tcp_map_seq(struct rds_tcp_connection *tc, u32 seq);
 extern struct rds_transport rds_tcp_transport;
 void rds_tcp_accept_work(struct sock *sk);
+bool rds_tcp_netns_delete_pending(struct net *net);
 
 /* tcp_connect.c */
 int rds_tcp_conn_path_connect(struct rds_conn_path *cp);
-- 
1.7.1

Reply via email to