An AI-generated review reported a potential DMA use-after-free issue
[1]. If netc_xmit_ntmp_cmd() times out and returns an error, the pending
command is not explicitly aborted, yet ntmp_free_data_mem()
unconditionally frees the DMA buffer. If the buffer has already been
reallocated elsewhere, this may lead to silent memory corruption, because
the hardware eventually processes the pending command and performs a DMA
write of the response to the physical address of the freed buffer.

To resolve this issue, this patch makes the following changes:

1. Convert cbdr->ring_lock from a spinlock to a mutex

The lock was originally a spinlock in case NTMP operations were ever
invoked from atomic context. Even with downstream support for all NTMP
tables, no such usage has materialized. A mutex is now required
because the driver needs to reclaim used BDs and release the associated
DMA memory while holding the lock, and dma_free_coherent() might
sleep. Holding a mutex also allows the completion poll in
netc_xmit_ntmp_cmd() to use read_poll_timeout() instead of
read_poll_timeout_atomic().

2. Introduce software command BD (struct netc_swcbd)

The hardware write-back overwrites the addr and len fields of the BD,
so the driver cannot rely on the hardware BD to free the associated DMA
memory. The driver now maintains a software shadow BD storing the DMA
buffer pointer, DMA address, and size, which allows the DMA memory to be
released correctly. In addition, netc_xmit_ntmp_cmd() no longer cleans
the ring after every command; it reclaims completed BDs only once the
number of used BDs reaches NETC_CBDR_CLEAN_WORK (16), and returns -EBUSY
if no BD is free after cleaning. With this, struct ntmp_dma_buf is no
longer needed and is removed, and ntmp_free_data_mem() now takes a
struct netc_swcbd and is called only when BDs are reclaimed or the ring
is freed.

These changes eliminate the DMA use-after-free condition and ensure safe
and consistent BD reclamation and DMA buffer lifecycle management.

Fixes: 4701073c3deb ("net: enetc: add initial netc-lib driver to support NTMP")
Link: https://lore.kernel.org/netdev/[email protected]/ # [1]
Signed-off-by: Wei Fang <[email protected]>
---
 drivers/net/ethernet/freescale/enetc/ntmp.c   | 158 ++++++++++--------
 .../ethernet/freescale/enetc/ntmp_private.h   |   8 +-
 include/linux/fsl/ntmp.h                      |   9 +-
 3 files changed, 93 insertions(+), 82 deletions(-)

diff --git a/drivers/net/ethernet/freescale/enetc/ntmp.c 
b/drivers/net/ethernet/freescale/enetc/ntmp.c
index 1b1ff0446d0a..3efc65443113 100644
--- a/drivers/net/ethernet/freescale/enetc/ntmp.c
+++ b/drivers/net/ethernet/freescale/enetc/ntmp.c
@@ -7,6 +7,7 @@
 #include <linux/dma-mapping.h>
 #include <linux/fsl/netc_global.h>
 #include <linux/iopoll.h>
+#include <linux/vmalloc.h>
 
 #include "ntmp_private.h"
 
@@ -42,6 +43,12 @@ int ntmp_init_cbdr(struct netc_cbdr *cbdr, struct device 
*dev,
        if (!cbdr->addr_base)
                return -ENOMEM;
 
+       cbdr->swcbd = vcalloc(cbd_num, sizeof(struct netc_swcbd));
+       if (!cbdr->swcbd) {
+               dma_free_coherent(dev, size, cbdr->addr_base, cbdr->dma_base);
+               return -ENOMEM;
+       }
+
        cbdr->dma_size = size;
        cbdr->bd_num = cbd_num;
        cbdr->regs = *regs;
@@ -52,7 +59,7 @@ int ntmp_init_cbdr(struct netc_cbdr *cbdr, struct device *dev,
        cbdr->addr_base_align = PTR_ALIGN(cbdr->addr_base,
                                          NTMP_BASE_ADDR_ALIGN);
 
-       spin_lock_init(&cbdr->ring_lock);
+       mutex_init(&cbdr->ring_lock);
 
        cbdr->next_to_use = netc_read(cbdr->regs.pir);
        cbdr->next_to_clean = netc_read(cbdr->regs.cir) & NETC_CBDRCIR_INDEX;
@@ -71,10 +78,25 @@ int ntmp_init_cbdr(struct netc_cbdr *cbdr, struct device 
*dev,
 }
 EXPORT_SYMBOL_GPL(ntmp_init_cbdr);
 
+static void ntmp_free_data_mem(struct device *dev, struct netc_swcbd *swcbd)
+{
+       dma_free_coherent(dev, swcbd->size + NTMP_DATA_ADDR_ALIGN,
+                         swcbd->buf, swcbd->dma);
+}
+
 void ntmp_free_cbdr(struct netc_cbdr *cbdr)
 {
        /* Disable the Control BD Ring */
        netc_write(cbdr->regs.mr, 0);
+
+       for (int i = 0; i < cbdr->bd_num; i++) {
+               struct netc_swcbd *swcbd = &cbdr->swcbd[i];
+
+               if (swcbd->dma)
+                       ntmp_free_data_mem(cbdr->dev, swcbd);
+       }
+
+       vfree(cbdr->swcbd);
        dma_free_coherent(cbdr->dev, cbdr->dma_size, cbdr->addr_base,
                          cbdr->dma_base);
        memset(cbdr, 0, sizeof(*cbdr));
@@ -94,24 +116,28 @@ static union netc_cbd *ntmp_get_cbd(struct netc_cbdr 
*cbdr, int index)
 
 static void ntmp_clean_cbdr(struct netc_cbdr *cbdr)
 {
-       union netc_cbd *cbd;
-       int i;
+       int i = cbdr->next_to_clean;
 
-       i = cbdr->next_to_clean;
        while ((netc_read(cbdr->regs.cir) & NETC_CBDRCIR_INDEX) != i) {
-               cbd = ntmp_get_cbd(cbdr, i);
+               union netc_cbd *cbd = ntmp_get_cbd(cbdr, i);
+               struct netc_swcbd *swcbd = &cbdr->swcbd[i];
+
+               ntmp_free_data_mem(cbdr->dev, swcbd);
+               memset(swcbd, 0, sizeof(*swcbd));
                memset(cbd, 0, sizeof(*cbd));
                i = (i + 1) % cbdr->bd_num;
        }
 
+       dma_wmb();
        cbdr->next_to_clean = i;
 }
 
-static int netc_xmit_ntmp_cmd(struct ntmp_user *user, union netc_cbd *cbd)
+static int netc_xmit_ntmp_cmd(struct ntmp_user *user, union netc_cbd *cbd,
+                             struct netc_swcbd *swcbd)
 {
        union netc_cbd *cur_cbd;
        struct netc_cbdr *cbdr;
-       int i, err;
+       int i, err, used_bds;
        u16 status;
        u32 val;
 
@@ -120,14 +146,21 @@ static int netc_xmit_ntmp_cmd(struct ntmp_user *user, 
union netc_cbd *cbd)
         */
        cbdr = &user->ring[0];
 
-       spin_lock_bh(&cbdr->ring_lock);
+       mutex_lock(&cbdr->ring_lock);
 
-       if (unlikely(!ntmp_get_free_cbd_num(cbdr)))
+       used_bds = cbdr->bd_num - ntmp_get_free_cbd_num(cbdr);
+       if (unlikely(used_bds >= NETC_CBDR_CLEAN_WORK)) {
                ntmp_clean_cbdr(cbdr);
+               if (unlikely(!ntmp_get_free_cbd_num(cbdr))) {
+                       err = -EBUSY;
+                       goto cbdr_unlock;
+               }
+       }
 
        i = cbdr->next_to_use;
        cur_cbd = ntmp_get_cbd(cbdr, i);
        *cur_cbd = *cbd;
+       cbdr->swcbd[i] = *swcbd;
        dma_wmb();
 
        /* Update producer index of both software and hardware */
@@ -135,10 +168,9 @@ static int netc_xmit_ntmp_cmd(struct ntmp_user *user, 
union netc_cbd *cbd)
        cbdr->next_to_use = i;
        netc_write(cbdr->regs.pir, i);
 
-       err = read_poll_timeout_atomic(netc_read, val,
-                                      (val & NETC_CBDRCIR_INDEX) == i,
-                                      NETC_CBDR_DELAY_US, NETC_CBDR_TIMEOUT,
-                                      true, cbdr->regs.cir);
+       err = read_poll_timeout(netc_read, val, (val & NETC_CBDRCIR_INDEX) == i,
+                               NETC_CBDR_DELAY_US, NETC_CBDR_TIMEOUT,
+                               true, cbdr->regs.cir);
        if (unlikely(err))
                goto cbdr_unlock;
 
@@ -155,36 +187,28 @@ static int netc_xmit_ntmp_cmd(struct ntmp_user *user, 
union netc_cbd *cbd)
                dev_err(user->dev, "Command BD error: 0x%04x\n", status);
        }
 
-       ntmp_clean_cbdr(cbdr);
-       dma_wmb();
-
 cbdr_unlock:
-       spin_unlock_bh(&cbdr->ring_lock);
+       mutex_unlock(&cbdr->ring_lock);
 
        return err;
 }
 
-static int ntmp_alloc_data_mem(struct ntmp_dma_buf *data, void **buf_align)
+static int ntmp_alloc_data_mem(struct device *dev, struct netc_swcbd *swcbd,
+                              void **buf_align)
 {
        void *buf;
 
-       buf = dma_alloc_coherent(data->dev, data->size + NTMP_DATA_ADDR_ALIGN,
-                                &data->dma, GFP_KERNEL);
+       buf = dma_alloc_coherent(dev, swcbd->size + NTMP_DATA_ADDR_ALIGN,
+                                &swcbd->dma, GFP_KERNEL);
        if (!buf)
                return -ENOMEM;
 
-       data->buf = buf;
+       swcbd->buf = buf;
        *buf_align = PTR_ALIGN(buf, NTMP_DATA_ADDR_ALIGN);
 
        return 0;
 }
 
-static void ntmp_free_data_mem(struct ntmp_dma_buf *data)
-{
-       dma_free_coherent(data->dev, data->size + NTMP_DATA_ADDR_ALIGN,
-                         data->buf, data->dma);
-}
-
 static void ntmp_fill_request_hdr(union netc_cbd *cbd, dma_addr_t dma,
                                  int len, int table_id, int cmd,
                                  int access_method)
@@ -235,37 +259,36 @@ static int ntmp_delete_entry_by_id(struct ntmp_user 
*user, int tbl_id,
                                   u8 tbl_ver, u32 entry_id, u32 req_len,
                                   u32 resp_len)
 {
-       struct ntmp_dma_buf data = {
-               .dev = user->dev,
+       struct netc_swcbd swcbd = {
                .size = max(req_len, resp_len),
        };
        struct ntmp_req_by_eid *req;
        union netc_cbd cbd;
        int err;
 
-       err = ntmp_alloc_data_mem(&data, (void **)&req);
+       err = ntmp_alloc_data_mem(user->dev, &swcbd, (void **)&req);
        if (err)
                return err;
 
        ntmp_fill_crd_eid(req, tbl_ver, 0, 0, entry_id);
-       ntmp_fill_request_hdr(&cbd, data.dma, NTMP_LEN(req_len, resp_len),
+       ntmp_fill_request_hdr(&cbd, swcbd.dma, NTMP_LEN(req_len, resp_len),
                              tbl_id, NTMP_CMD_DELETE, NTMP_AM_ENTRY_ID);
 
-       err = netc_xmit_ntmp_cmd(user, &cbd);
+       err = netc_xmit_ntmp_cmd(user, &cbd, &swcbd);
        if (err)
                dev_err(user->dev,
                        "Failed to delete entry 0x%x of %s, err: %pe",
                        entry_id, ntmp_table_name(tbl_id), ERR_PTR(err));
 
-       ntmp_free_data_mem(&data);
-
        return err;
 }
 
 static int ntmp_query_entry_by_id(struct ntmp_user *user, int tbl_id,
-                                 u32 len, struct ntmp_req_by_eid *req,
-                                 dma_addr_t dma, bool compare_eid)
+                                 struct ntmp_req_by_eid *req,
+                                 struct netc_swcbd *swcbd,
+                                 bool compare_eid)
 {
+       u32 len = NTMP_LEN(sizeof(*req), swcbd->size);
        struct ntmp_cmn_resp_query *resp;
        int cmd = NTMP_CMD_QUERY;
        union netc_cbd cbd;
@@ -277,8 +300,9 @@ static int ntmp_query_entry_by_id(struct ntmp_user *user, 
int tbl_id,
                cmd = NTMP_CMD_QU;
 
        /* Request header */
-       ntmp_fill_request_hdr(&cbd, dma, len, tbl_id, cmd, NTMP_AM_ENTRY_ID);
-       err = netc_xmit_ntmp_cmd(user, &cbd);
+       ntmp_fill_request_hdr(&cbd, swcbd->dma, len, tbl_id, cmd,
+                             NTMP_AM_ENTRY_ID);
+       err = netc_xmit_ntmp_cmd(user, &cbd, swcbd);
        if (err) {
                dev_err(user->dev,
                        "Failed to query entry 0x%x of %s, err: %pe\n",
@@ -306,15 +330,14 @@ static int ntmp_query_entry_by_id(struct ntmp_user *user, 
int tbl_id,
 int ntmp_maft_add_entry(struct ntmp_user *user, u32 entry_id,
                        struct maft_entry_data *maft)
 {
-       struct ntmp_dma_buf data = {
-               .dev = user->dev,
+       struct netc_swcbd swcbd = {
                .size = sizeof(struct maft_req_add),
        };
        struct maft_req_add *req;
        union netc_cbd cbd;
        int err;
 
-       err = ntmp_alloc_data_mem(&data, (void **)&req);
+       err = ntmp_alloc_data_mem(user->dev, &swcbd, (void **)&req);
        if (err)
                return err;
 
@@ -323,15 +346,13 @@ int ntmp_maft_add_entry(struct ntmp_user *user, u32 
entry_id,
        req->keye = maft->keye;
        req->cfge = maft->cfge;
 
-       ntmp_fill_request_hdr(&cbd, data.dma, NTMP_LEN(data.size, 0),
+       ntmp_fill_request_hdr(&cbd, swcbd.dma, NTMP_LEN(swcbd.size, 0),
                              NTMP_MAFT_ID, NTMP_CMD_ADD, NTMP_AM_ENTRY_ID);
-       err = netc_xmit_ntmp_cmd(user, &cbd);
+       err = netc_xmit_ntmp_cmd(user, &cbd, &swcbd);
        if (err)
                dev_err(user->dev, "Failed to add MAFT entry 0x%x, err: %pe\n",
                        entry_id, ERR_PTR(err));
 
-       ntmp_free_data_mem(&data);
-
        return err;
 }
 EXPORT_SYMBOL_GPL(ntmp_maft_add_entry);
@@ -339,33 +360,27 @@ EXPORT_SYMBOL_GPL(ntmp_maft_add_entry);
 int ntmp_maft_query_entry(struct ntmp_user *user, u32 entry_id,
                          struct maft_entry_data *maft)
 {
-       struct ntmp_dma_buf data = {
-               .dev = user->dev,
+       struct netc_swcbd swcbd = {
                .size = sizeof(struct maft_resp_query),
        };
        struct maft_resp_query *resp;
        struct ntmp_req_by_eid *req;
        int err;
 
-       err = ntmp_alloc_data_mem(&data, (void **)&req);
+       err = ntmp_alloc_data_mem(user->dev, &swcbd, (void **)&req);
        if (err)
                return err;
 
        ntmp_fill_crd_eid(req, user->tbl.maft_ver, 0, 0, entry_id);
-       err = ntmp_query_entry_by_id(user, NTMP_MAFT_ID,
-                                    NTMP_LEN(sizeof(*req), data.size),
-                                    req, data.dma, true);
+       err = ntmp_query_entry_by_id(user, NTMP_MAFT_ID, req, &swcbd, true);
        if (err)
-               goto end;
+               return err;
 
        resp = (struct maft_resp_query *)req;
        maft->keye = resp->keye;
        maft->cfge = resp->cfge;
 
-end:
-       ntmp_free_data_mem(&data);
-
-       return err;
+       return 0;
 }
 EXPORT_SYMBOL_GPL(ntmp_maft_query_entry);
 
@@ -379,8 +394,8 @@ EXPORT_SYMBOL_GPL(ntmp_maft_delete_entry);
 int ntmp_rsst_update_entry(struct ntmp_user *user, const u32 *table,
                           int count)
 {
-       struct ntmp_dma_buf data = {.dev = user->dev};
        struct rsst_req_update *req;
+       struct netc_swcbd swcbd;
        union netc_cbd cbd;
        int err, i;
 
@@ -388,8 +403,8 @@ int ntmp_rsst_update_entry(struct ntmp_user *user, const 
u32 *table,
                /* HW only takes in a full 64 entry table */
                return -EINVAL;
 
-       data.size = struct_size(req, groups, count);
-       err = ntmp_alloc_data_mem(&data, (void **)&req);
+       swcbd.size = struct_size(req, groups, count);
+       err = ntmp_alloc_data_mem(user->dev, &swcbd, (void **)&req);
        if (err)
                return err;
 
@@ -399,24 +414,22 @@ int ntmp_rsst_update_entry(struct ntmp_user *user, const 
u32 *table,
        for (i = 0; i < count; i++)
                req->groups[i] = (u8)(table[i]);
 
-       ntmp_fill_request_hdr(&cbd, data.dma, NTMP_LEN(data.size, 0),
+       ntmp_fill_request_hdr(&cbd, swcbd.dma, NTMP_LEN(swcbd.size, 0),
                              NTMP_RSST_ID, NTMP_CMD_UPDATE, NTMP_AM_ENTRY_ID);
 
-       err = netc_xmit_ntmp_cmd(user, &cbd);
+       err = netc_xmit_ntmp_cmd(user, &cbd, &swcbd);
        if (err)
                dev_err(user->dev, "Failed to update RSST entry, err: %pe\n",
                        ERR_PTR(err));
 
-       ntmp_free_data_mem(&data);
-
        return err;
 }
 EXPORT_SYMBOL_GPL(ntmp_rsst_update_entry);
 
 int ntmp_rsst_query_entry(struct ntmp_user *user, u32 *table, int count)
 {
-       struct ntmp_dma_buf data = {.dev = user->dev};
        struct ntmp_req_by_eid *req;
+       struct netc_swcbd swcbd;
        union netc_cbd cbd;
        int err, i;
        u8 *group;
@@ -425,21 +438,21 @@ int ntmp_rsst_query_entry(struct ntmp_user *user, u32 
*table, int count)
                /* HW only takes in a full 64 entry table */
                return -EINVAL;
 
-       data.size = NTMP_ENTRY_ID_SIZE + RSST_STSE_DATA_SIZE(count) +
-                   RSST_CFGE_DATA_SIZE(count);
-       err = ntmp_alloc_data_mem(&data, (void **)&req);
+       swcbd.size = NTMP_ENTRY_ID_SIZE + RSST_STSE_DATA_SIZE(count) +
+                    RSST_CFGE_DATA_SIZE(count);
+       err = ntmp_alloc_data_mem(user->dev, &swcbd, (void **)&req);
        if (err)
                return err;
 
        /* Set the request data buffer */
        ntmp_fill_crd_eid(req, user->tbl.rsst_ver, 0, 0, 0);
-       ntmp_fill_request_hdr(&cbd, data.dma, NTMP_LEN(sizeof(*req), data.size),
+       ntmp_fill_request_hdr(&cbd, swcbd.dma, NTMP_LEN(sizeof(*req), 
swcbd.size),
                              NTMP_RSST_ID, NTMP_CMD_QUERY, NTMP_AM_ENTRY_ID);
-       err = netc_xmit_ntmp_cmd(user, &cbd);
+       err = netc_xmit_ntmp_cmd(user, &cbd, &swcbd);
        if (err) {
                dev_err(user->dev, "Failed to query RSST entry, err: %pe\n",
                        ERR_PTR(err));
-               goto end;
+               return err;
        }
 
        group = (u8 *)req;
@@ -447,10 +460,7 @@ int ntmp_rsst_query_entry(struct ntmp_user *user, u32 
*table, int count)
        for (i = 0; i < count; i++)
                table[i] = group[i];
 
-end:
-       ntmp_free_data_mem(&data);
-
-       return err;
+       return 0;
 }
 EXPORT_SYMBOL_GPL(ntmp_rsst_query_entry);
 
diff --git a/drivers/net/ethernet/freescale/enetc/ntmp_private.h 
b/drivers/net/ethernet/freescale/enetc/ntmp_private.h
index 7a53db8740db..5ae6f8b92700 100644
--- a/drivers/net/ethernet/freescale/enetc/ntmp_private.h
+++ b/drivers/net/ethernet/freescale/enetc/ntmp_private.h
@@ -13,6 +13,7 @@
 #define NTMP_EID_REQ_LEN       8
 #define NETC_CBDR_BD_NUM       256
 #define NETC_CBDRCIR_INDEX     GENMASK(9, 0)
+#define NETC_CBDR_CLEAN_WORK   16
 
 union netc_cbd {
        struct {
@@ -55,13 +56,6 @@ union netc_cbd {
        } resp_hdr; /* NTMP Response Message Header Format */
 };
 
-struct ntmp_dma_buf {
-       struct device *dev;
-       size_t size;
-       void *buf;
-       dma_addr_t dma;
-};
-
 struct ntmp_cmn_req_data {
        __le16 update_act;
        u8 dbg_opt;
diff --git a/include/linux/fsl/ntmp.h b/include/linux/fsl/ntmp.h
index 916dc4fe7de3..83a449b4d6ec 100644
--- a/include/linux/fsl/ntmp.h
+++ b/include/linux/fsl/ntmp.h
@@ -31,6 +31,12 @@ struct netc_tbl_vers {
        u8 rsst_ver;
 };
 
+struct netc_swcbd {
+       void *buf;
+       dma_addr_t dma;
+       size_t size;
+};
+
 struct netc_cbdr {
        struct device *dev;
        struct netc_cbdr_regs regs;
@@ -44,9 +50,10 @@ struct netc_cbdr {
        void *addr_base_align;
        dma_addr_t dma_base;
        dma_addr_t dma_base_align;
+       struct netc_swcbd *swcbd;
 
        /* Serialize the order of command BD ring */
-       spinlock_t ring_lock;
+       struct mutex ring_lock;
 };
 
 struct ntmp_user {
-- 
2.34.1


Reply via email to