From: Paolo Abeni <pab...@redhat.com>

The local address id is accessed locklessly by the NL PM: add all
the required READ_ONCE()/WRITE_ONCE() annotations. There is a caveat:
the local id can be initialized late in the subflow life-cycle, and
its validity is controlled by the local_id_valid flag.

Remove that flag and encode the validity in the local_id field
itself, using a negative value before initialization. That allows
accessing the field consistently with a single read operation.

Fixes: 0ee4261a3681 ("mptcp: implement mptcp_pm_remove_subflow")
Cc: sta...@vger.kernel.org
Signed-off-by: Paolo Abeni <pab...@redhat.com>
Reviewed-by: Mat Martineau <martin...@kernel.org>
Signed-off-by: Matthieu Baerts (NGI0) <matt...@kernel.org>
---
 net/mptcp/diag.c         |  2 +-
 net/mptcp/pm_netlink.c   |  6 +++---
 net/mptcp/pm_userspace.c |  2 +-
 net/mptcp/protocol.c     |  2 +-
 net/mptcp/protocol.h     | 15 ++++++++++++---
 net/mptcp/subflow.c      |  9 +++++----
 6 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/net/mptcp/diag.c b/net/mptcp/diag.c
index e57c5f47f035..6ff6f14674aa 100644
--- a/net/mptcp/diag.c
+++ b/net/mptcp/diag.c
@@ -65,7 +65,7 @@ static int subflow_get_info(struct sock *sk, struct sk_buff 
*skb)
                        sf->map_data_len) ||
            nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_FLAGS, flags) ||
            nla_put_u8(skb, MPTCP_SUBFLOW_ATTR_ID_REM, sf->remote_id) ||
-           nla_put_u8(skb, MPTCP_SUBFLOW_ATTR_ID_LOC, sf->local_id)) {
+           nla_put_u8(skb, MPTCP_SUBFLOW_ATTR_ID_LOC, 
subflow_get_local_id(sf))) {
                err = -EMSGSIZE;
                goto nla_failure;
        }
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index a24c9128dee9..912e25077437 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -800,7 +800,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct 
mptcp_sock *msk,
                mptcp_for_each_subflow_safe(msk, subflow, tmp) {
                        struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                        int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
-                       u8 id = subflow->local_id;
+                       u8 id = subflow_get_local_id(subflow);
 
                        if (rm_type == MPTCP_MIB_RMADDR && subflow->remote_id 
!= rm_id)
                                continue;
@@ -809,7 +809,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct 
mptcp_sock *msk,
 
                        pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u 
remote_id=%u mpc_id=%u",
                                 rm_type == MPTCP_MIB_RMADDR ? "address" : 
"subflow",
-                                i, rm_id, subflow->local_id, 
subflow->remote_id,
+                                i, rm_id, id, subflow->remote_id,
                                 msk->mpc_endpoint_id);
                        spin_unlock_bh(&msk->pm.lock);
                        mptcp_subflow_shutdown(sk, ssk, how);
@@ -1994,7 +1994,7 @@ static int mptcp_event_add_subflow(struct sk_buff *skb, 
const struct sock *ssk)
        if (WARN_ON_ONCE(!sf))
                return -EINVAL;
 
-       if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
+       if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, subflow_get_local_id(sf)))
                return -EMSGSIZE;
 
        if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c
index e582b3b2d174..d396a5973429 100644
--- a/net/mptcp/pm_userspace.c
+++ b/net/mptcp/pm_userspace.c
@@ -234,7 +234,7 @@ static int mptcp_userspace_pm_remove_id_zero_address(struct 
mptcp_sock *msk,
 
        lock_sock(sk);
        mptcp_for_each_subflow(msk, subflow) {
-               if (subflow->local_id == 0) {
+               if (READ_ONCE(subflow->local_id) == 0) {
                        has_id_0 = true;
                        break;
                }
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 8ef2927ebca2..948606a537da 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -85,7 +85,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
        subflow->subflow_id = msk->subflow_id++;
 
        /* This is the first subflow, always with id 0 */
-       subflow->local_id_valid = 1;
+       WRITE_ONCE(subflow->local_id, 0);
        mptcp_sock_graft(msk->first, sk->sk_socket);
        iput(SOCK_INODE(ssock));
 
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index ed50f2015dc3..631a7f445f34 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -491,10 +491,9 @@ struct mptcp_subflow_context {
                remote_key_valid : 1,        /* received the peer key from */
                disposable : 1,     /* ctx can be free at ulp release time */
                stale : 1,          /* unable to snd/rcv data, do not use for 
xmit */
-               local_id_valid : 1, /* local_id is correctly initialized */
                valid_csum_seen : 1,        /* at least one csum validated */
                is_mptfo : 1,       /* subflow is doing TFO */
-               __unused : 9;
+               __unused : 10;
        bool    data_avail;
        bool    scheduled;
        u32     remote_nonce;
@@ -505,7 +504,7 @@ struct mptcp_subflow_context {
                u8      hmac[MPTCPOPT_HMAC_LEN]; /* MPJ subflow only */
                u64     iasn;       /* initial ack sequence number, MPC 
subflows only */
        };
-       u8      local_id;
+       s16     local_id;           /* if negative not initialized yet */
        u8      remote_id;
        u8      reset_seen:1;
        u8      reset_transient:1;
@@ -556,6 +555,7 @@ mptcp_subflow_ctx_reset(struct mptcp_subflow_context 
*subflow)
 {
        memset(&subflow->reset, 0, sizeof(subflow->reset));
        subflow->request_mptcp = 1;
+       WRITE_ONCE(subflow->local_id, -1);
 }
 
 static inline u64
@@ -1022,6 +1022,15 @@ int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct 
sock_common *skc);
 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info 
*skc);
 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, struct 
mptcp_addr_info *skc);
 
+static inline u8 subflow_get_local_id(const struct mptcp_subflow_context 
*subflow)
+{
+       int local_id = READ_ONCE(subflow->local_id);
+
+       if (local_id < 0)
+               return 0;
+       return local_id;
+}
+
 void __init mptcp_pm_nl_init(void);
 void mptcp_pm_nl_work(struct mptcp_sock *msk);
 void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
index c34ecadee120..015184bbf06c 100644
--- a/net/mptcp/subflow.c
+++ b/net/mptcp/subflow.c
@@ -577,8 +577,8 @@ static void subflow_finish_connect(struct sock *sk, const 
struct sk_buff *skb)
 
 static void subflow_set_local_id(struct mptcp_subflow_context *subflow, int 
local_id)
 {
-       subflow->local_id = local_id;
-       subflow->local_id_valid = 1;
+       WARN_ON_ONCE(local_id < 0 || local_id > 255);
+       WRITE_ONCE(subflow->local_id, local_id);
 }
 
 static int subflow_chk_local_id(struct sock *sk)
@@ -587,7 +587,7 @@ static int subflow_chk_local_id(struct sock *sk)
        struct mptcp_sock *msk = mptcp_sk(subflow->conn);
        int err;
 
-       if (likely(subflow->local_id_valid))
+       if (likely(subflow->local_id >= 0))
                return 0;
 
        err = mptcp_pm_get_local_id(msk, (struct sock_common *)sk);
@@ -1731,6 +1731,7 @@ static struct mptcp_subflow_context 
*subflow_create_ctx(struct sock *sk,
        pr_debug("subflow=%p", ctx);
 
        ctx->tcp_sock = sk;
+       WRITE_ONCE(ctx->local_id, -1);
 
        return ctx;
 }
@@ -1966,7 +1967,7 @@ static void subflow_ulp_clone(const struct request_sock 
*req,
                new_ctx->idsn = subflow_req->idsn;
 
                /* this is the first subflow, id is always 0 */
-               new_ctx->local_id_valid = 1;
+               subflow_set_local_id(new_ctx, 0);
        } else if (subflow_req->mp_join) {
                new_ctx->ssn_offset = subflow_req->ssn_offset;
                new_ctx->mp_join = 1;

-- 
2.43.0


Reply via email to