After device_initialize(), the lifetime of the embedded struct device is
expected to be managed through the device core reference counting.

In __devm_create_dev_dax(), the failure paths taken after
device_initialize() free dev_dax directly with kfree() instead of
dropping the device reference with put_device(). This bypasses the
normal device lifetime rules: the reference count of the embedded
struct device is left unbalanced (a refcount leak), the name allocated
by dev_set_name() is never freed, and freeing memory that is still
reference-counted can potentially lead to a use-after-free.

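Fix this by assigning dev->type before device_initialize(), so that the
dev_dax_type release callback is in place when put_device() drops the
last reference, and use put_device() instead of kfree() in the
post-initialization error paths. The release callback then takes care
of freeing the dax id and dev_dax itself, so the explicit
free_dev_dax_id() call is dropped, and dev_dax->pgmap is cleared after
it is freed so it is not freed a second time on release. The dev_dax
range cleanup stays explicit in the error path.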

Fixes: c2f3011ee697f ("device-dax: add an allocation interface for device-dax instances")
Cc: [email protected]
Signed-off-by: Guangshuo Li <[email protected]>
---
 drivers/dax/bus.c | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/drivers/dax/bus.c b/drivers/dax/bus.c
index fde29e0ad68b..8753115cd371 100644
--- a/drivers/dax/bus.c
+++ b/drivers/dax/bus.c
@@ -1453,6 +1453,7 @@ static struct dev_dax *__devm_create_dev_dax(struct dev_dax_data *data)
        }
 
        dev = &dev_dax->dev;
+       dev->type = &dev_dax_type;
        device_initialize(dev);
        dev_set_name(dev, "dax%d.%d", dax_region->id, dev_dax->id);
 
@@ -1499,7 +1500,6 @@ static struct dev_dax *__devm_create_dev_dax(struct dev_dax_data *data)
        dev->devt = inode->i_rdev;
        dev->bus = &dax_bus_type;
        dev->parent = parent;
-       dev->type = &dev_dax_type;
 
        rc = device_add(dev);
        if (rc) {
@@ -1523,14 +1523,21 @@ static struct dev_dax *__devm_create_dev_dax(struct dev_dax_data *data)
 
 err_alloc_dax:
        kfree(dev_dax->pgmap);
+       dev_dax->pgmap = NULL;
+
 err_pgmap:
        free_dev_dax_ranges(dev_dax);
+       put_device(dev);
+       return ERR_PTR(rc);
+
 err_range:
-       free_dev_dax_id(dev_dax);
+       put_device(dev);
+       return ERR_PTR(rc);
+
 err_id:
        kfree(dev_dax);
-
        return ERR_PTR(rc);
+
 }
 
 struct dev_dax *devm_create_dev_dax(struct dev_dax_data *data)
-- 
2.43.0


Reply via email to