The css_driver's main purpose is to create/destroy the mdev and relay the
shutdown, irq, sch_event, and chp_event css_driver ops to the single
created vfio_device, if it exists.

Reframe the boundary where the css_driver domain switches to the vfio
domain: use RCU to read the vfio_device out of the sch's drvdata and take
a reference on it. The mdev probe/remove now manage the drvdata of the
parent subchannel.

The vfio core refcounting thus guarantees that, while a css_driver
callback holds that reference, the vfio_device is registered. This
simplifies reasoning about the whole lifecycle.

Finally, the vfio_ccw_private is allocated/freed during probe/remove of the
mdev, like any other vfio_device struct.

Signed-off-by: Jason Gunthorpe <j...@nvidia.com>
---
 drivers/s390/cio/vfio_ccw_drv.c     | 67 ++++++++++++++---------------
 drivers/s390/cio/vfio_ccw_ops.c     | 40 +++++++----------
 drivers/s390/cio/vfio_ccw_private.h | 23 +++++++++-
 3 files changed, 69 insertions(+), 61 deletions(-)

diff --git a/drivers/s390/cio/vfio_ccw_drv.c b/drivers/s390/cio/vfio_ccw_drv.c
index 0e2edd96567a09..b86da53443bfd7 100644
--- a/drivers/s390/cio/vfio_ccw_drv.c
+++ b/drivers/s390/cio/vfio_ccw_drv.c
@@ -86,13 +86,19 @@ static void vfio_ccw_crw_todo(struct work_struct *work)
  */
 static void vfio_ccw_sch_irq(struct subchannel *sch)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(&sch->dev);
+       struct vfio_ccw_private *private = vfio_ccw_get_priv(sch);
+
+       /* No IRQ is expected unless a vfio_device is registered for sch */
+       if (WARN_ON(!private))
+               return;
 
        inc_irq_stat(IRQIO_CIO);
        vfio_ccw_fsm_event(private, VFIO_CCW_EVENT_INTERRUPT);
+       vfio_device_put(&private->vdev);
 }
 
-static struct vfio_ccw_private *vfio_ccw_alloc_private(struct subchannel *sch)
+struct vfio_ccw_private *vfio_ccw_alloc_private(struct mdev_device *mdev,
+                                               struct subchannel *sch)
 {
        struct vfio_ccw_private *private;
 
@@ -100,6 +106,8 @@ static struct vfio_ccw_private *vfio_ccw_alloc_private(struct subchannel *sch)
        if (!private)
                return ERR_PTR(-ENOMEM);
 
+       vfio_init_group_dev(&private->vdev, &mdev->dev,
+                           &vfio_ccw_dev_ops);
        private->sch = sch;
        mutex_init(&private->io_mutex);
        private->state = VFIO_CCW_STATE_CLOSED;
@@ -145,11 +153,12 @@ static struct vfio_ccw_private *vfio_ccw_alloc_private(struct subchannel *sch)
        kfree(private->cp.guest_cp);
 out_free_private:
        mutex_destroy(&private->io_mutex);
+       vfio_uninit_group_dev(&private->vdev);
        kfree(private);
        return ERR_PTR(-ENOMEM);
 }
 
-static void vfio_ccw_free_private(struct vfio_ccw_private *private)
+void vfio_ccw_free_private(struct vfio_ccw_private *private)
 {
        struct vfio_ccw_crw *crw, *temp;
 
@@ -164,14 +173,14 @@ static void vfio_ccw_free_private(struct vfio_ccw_private *private)
        kmem_cache_free(vfio_ccw_io_region, private->io_region);
        kfree(private->cp.guest_cp);
        mutex_destroy(&private->io_mutex);
-       kfree(private);
+       vfio_uninit_group_dev(&private->vdev);
+       kfree_rcu(private, rcu);
 }
 
 static int vfio_ccw_sch_probe(struct subchannel *sch)
 {
        struct pmcw *pmcw = &sch->schib.pmcw;
-       struct vfio_ccw_private *private;
-       int ret = -ENOMEM;
+       int ret;
 
        if (pmcw->qf) {
                dev_warn(&sch->dev, "vfio: ccw: does not support QDIO: %s\n",
@@ -179,15 +188,9 @@ static int vfio_ccw_sch_probe(struct subchannel *sch)
                return -ENODEV;
        }
 
-       private = vfio_ccw_alloc_private(sch);
-       if (IS_ERR(private))
-               return PTR_ERR(private);
-
-       dev_set_drvdata(&sch->dev, private);
-
-       ret = vfio_ccw_mdev_reg(sch);
+       ret = mdev_register_device(&sch->dev, &vfio_ccw_mdev_ops);
        if (ret)
-               goto out_free;
+               return ret;
 
        if (dev_get_uevent_suppress(&sch->dev)) {
                dev_set_uevent_suppress(&sch->dev, 0);
@@ -198,22 +201,11 @@ static int vfio_ccw_sch_probe(struct subchannel *sch)
                           sch->schid.cssid, sch->schid.ssid,
                           sch->schid.sch_no);
        return 0;
-
-out_free:
-       dev_set_drvdata(&sch->dev, NULL);
-       vfio_ccw_free_private(private);
-       return ret;
 }
 
 static int vfio_ccw_sch_remove(struct subchannel *sch)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(&sch->dev);
-
-       vfio_ccw_mdev_unreg(sch);
-
-       dev_set_drvdata(&sch->dev, NULL);
-
-       vfio_ccw_free_private(private);
+       mdev_unregister_device(&sch->dev);
 
        VFIO_CCW_MSG_EVENT(4, "unbound from subchannel %x.%x.%04x\n",
                           sch->schid.cssid, sch->schid.ssid,
@@ -223,10 +215,14 @@ static int vfio_ccw_sch_remove(struct subchannel *sch)
 
 static void vfio_ccw_sch_shutdown(struct subchannel *sch)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(&sch->dev);
+       struct vfio_ccw_private *private = vfio_ccw_get_priv(sch);
+
+       if (!private)
+               return;
 
        vfio_ccw_fsm_event(private, VFIO_CCW_EVENT_CLOSE);
        vfio_ccw_fsm_event(private, VFIO_CCW_EVENT_BROKEN);
+       vfio_device_put(&private->vdev);
 }
 
 /**
@@ -241,14 +237,29 @@ static void vfio_ccw_sch_shutdown(struct subchannel *sch)
  */
 static int vfio_ccw_sch_event(struct subchannel *sch, int process)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(&sch->dev);
+       struct vfio_ccw_private *private = vfio_ccw_get_priv(sch);
        unsigned long flags;
        int rc = -EAGAIN;
 
-       spin_lock_irqsave(sch->lock, flags);
-       if (!device_is_registered(&sch->dev))
-               goto out_unlock;
+       if (!private) {
+               /*
+                * Without an mdev there is no vfio_device to notify, but the
+                * event must still be consumed: -EAGAIN makes the css layer
+                * retry the evaluation, which would repeat for as long as no
+                * mdev exists. Unregister the subchannel if it has gone away.
+                */
+               spin_lock_irqsave(sch->lock, flags);
+               if (device_is_registered(&sch->dev) &&
+                   !work_pending(&sch->todo_work)) {
+                       if (cio_update_schib(sch))
+                               css_sched_sch_todo(sch, SCH_TODO_UNREG);
+                       rc = 0;
+               }
+               spin_unlock_irqrestore(sch->lock, flags);
+               return rc;
+       }
 
+       spin_lock_irqsave(sch->lock, flags);
        if (work_pending(&sch->todo_work))
                goto out_unlock;

@@ -261,7 +272,7 @@ static int vfio_ccw_sch_event(struct subchannel *sch, int process)
 
 out_unlock:
        spin_unlock_irqrestore(sch->lock, flags);
-
+       vfio_device_put(&private->vdev);
        return rc;
 }
 
@@ -295,7 +306,9 @@ static void vfio_ccw_queue_crw(struct vfio_ccw_private *private,
 static int vfio_ccw_chp_event(struct subchannel *sch,
                               struct chp_link *link, int event)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(&sch->dev);
        int mask = chp_ssd_get_mask(&sch->ssd_info, link);
+       /* No ref when mask == 0: the early !mask return would not drop it */
+       struct vfio_ccw_private *private =
+               mask ? vfio_ccw_get_priv(sch) : NULL;
        int retry = 255;
 
@@ -308,8 +321,10 @@ static int vfio_ccw_chp_event(struct subchannel *sch,
                           sch->schid.ssid, sch->schid.sch_no,
                           mask, event);
 
-       if (cio_update_schib(sch))
+       if (cio_update_schib(sch)) {
+               vfio_device_put(&private->vdev);
                return -ENODEV;
+       }
 
        switch (event) {
        case CHP_VARY_OFF:
@@ -339,6 +354,7 @@ static int vfio_ccw_chp_event(struct subchannel *sch,
                break;
        }
 
+       vfio_device_put(&private->vdev);
        return 0;
 }
 
diff --git a/drivers/s390/cio/vfio_ccw_ops.c b/drivers/s390/cio/vfio_ccw_ops.c
index 23004e67c492f6..04a10f37d64225 100644
--- a/drivers/s390/cio/vfio_ccw_ops.c
+++ b/drivers/s390/cio/vfio_ccw_ops.c
@@ -17,8 +17,6 @@
 
 #include "vfio_ccw_private.h"
 
-static const struct vfio_device_ops vfio_ccw_dev_ops;
-
 static int vfio_ccw_mdev_reset(struct vfio_ccw_private *private)
 {
        /*
@@ -88,26 +86,35 @@ static struct attribute_group *mdev_type_groups[] = {
 
 static int vfio_ccw_mdev_probe(struct mdev_device *mdev)
 {
-       struct vfio_ccw_private *private = dev_get_drvdata(mdev->dev.parent);
+       struct subchannel *sch = to_subchannel(mdev->dev.parent);
+       struct vfio_ccw_private *private;
        int ret;
 
-       memset(&private->vdev, 0, sizeof(private->vdev));
-       vfio_init_group_dev(&private->vdev, &mdev->dev,
-                           &vfio_ccw_dev_ops);
+       private = vfio_ccw_alloc_private(mdev, sch);
+       if (IS_ERR(private))
+               return PTR_ERR(private);
 
        VFIO_CCW_MSG_EVENT(2, "mdev %s, sch %x.%x.%04x: create\n",
-                          dev_name(private->vdev.dev),
-                          private->sch->schid.cssid, private->sch->schid.ssid,
-                          private->sch->schid.sch_no);
+                          dev_name(private->vdev.dev), sch->schid.cssid,
+                          sch->schid.ssid, sch->schid.sch_no);
 
+       /*
+        * Publish to the subchannel before registering so that css callbacks
+        * reach the vfio_device as soon as userspace can open it. This relies
+        * on vfio_device_try_get() failing until vfio_register_group_dev()
+        * succeeds, so vfio_ccw_get_priv() returns NULL before that.
+        */
+       dev_set_drvdata(&sch->dev, private);
+
        ret = vfio_register_group_dev(&private->vdev);
        if (ret)
-               goto err_init;
+               goto err_alloc;
        dev_set_drvdata(&mdev->dev, private);
        return 0;
 
-err_init:
-       vfio_uninit_group_dev(&private->vdev);
+err_alloc:
+       dev_set_drvdata(&sch->dev, NULL);
+       vfio_ccw_free_private(private);
        return ret;
 }
 
@@ -120,8 +127,14 @@ static void vfio_ccw_mdev_remove(struct mdev_device *mdev)
                           private->sch->schid.cssid, private->sch->schid.ssid,
                           private->sch->schid.sch_no);
 
        vfio_unregister_group_dev(&private->vdev);
-       vfio_uninit_group_dev(&private->vdev);
+       /*
+        * Keep the drvdata until unregistration completes so css callbacks,
+        * such as IRQs for a still-open device, are not lost. Later lookups
+        * fail vfio_device_try_get(); kfree_rcu() covers racing readers.
+        */
+       dev_set_drvdata(&private->sch->dev, NULL);
+       vfio_ccw_free_private(private);
 }
 
 static int vfio_ccw_mdev_open_device(struct vfio_device *vdev)
@@ -595,7 +608,7 @@ static unsigned int vfio_ccw_get_available(struct mdev_type *mtype)
        return 1;
 }
 
-static const struct vfio_device_ops vfio_ccw_dev_ops = {
+const struct vfio_device_ops vfio_ccw_dev_ops = {
        .open_device = vfio_ccw_mdev_open_device,
        .close_device = vfio_ccw_mdev_close_device,
        .read = vfio_ccw_mdev_read,
@@ -615,19 +628,9 @@ struct mdev_driver vfio_ccw_mdev_driver = {
        .get_available = vfio_ccw_get_available,
 };
 
-static const struct mdev_parent_ops vfio_ccw_mdev_ops = {
+const struct mdev_parent_ops vfio_ccw_mdev_ops = {
        .owner                  = THIS_MODULE,
        .device_driver          = &vfio_ccw_mdev_driver,
        .device_api             = VFIO_DEVICE_API_CCW_STRING,
        .supported_type_groups  = mdev_type_groups,
 };
-
-int vfio_ccw_mdev_reg(struct subchannel *sch)
-{
-       return mdev_register_device(&sch->dev, &vfio_ccw_mdev_ops);
-}
-
-void vfio_ccw_mdev_unreg(struct subchannel *sch)
-{
-       mdev_unregister_device(&sch->dev);
-}
diff --git a/drivers/s390/cio/vfio_ccw_private.h b/drivers/s390/cio/vfio_ccw_private.h
index 67ee9c624393b0..852ff94fc107d6 100644
--- a/drivers/s390/cio/vfio_ccw_private.h
+++ b/drivers/s390/cio/vfio_ccw_private.h
@@ -24,6 +24,8 @@
 #include "css.h"
 #include "vfio_ccw_cp.h"
 
+struct mdev_device;
+
 #define VFIO_CCW_OFFSET_SHIFT   10
 #define VFIO_CCW_OFFSET_TO_INDEX(off)  (off >> VFIO_CCW_OFFSET_SHIFT)
 #define VFIO_CCW_INDEX_TO_OFFSET(index)        ((u64)(index) << VFIO_CCW_OFFSET_SHIFT)
@@ -69,6 +71,7 @@ struct vfio_ccw_crw {
 /**
  * struct vfio_ccw_private
  * @vdev: Embedded VFIO device
+ * @rcu: head for kfree_rcu()
  * @sch: pointer to the subchannel
  * @state: internal state of the device
  * @completion: synchronization helper of the I/O completion
@@ -91,6 +94,7 @@ struct vfio_ccw_crw {
  */
 struct vfio_ccw_private {
        struct vfio_device vdev;
+       struct rcu_head rcu;
        struct subchannel       *sch;
        int                     state;
        struct completion       *completion;
@@ -115,10 +119,34 @@ struct vfio_ccw_private {
        struct work_struct      crw_work;
 } __aligned(8);
 
-extern int vfio_ccw_mdev_reg(struct subchannel *sch);
-extern void vfio_ccw_mdev_unreg(struct subchannel *sch);
+struct vfio_ccw_private *vfio_ccw_alloc_private(struct mdev_device *mdev,
+                                               struct subchannel *sch);
+void vfio_ccw_free_private(struct vfio_ccw_private *private);
 
 extern struct mdev_driver vfio_ccw_mdev_driver;
+extern const struct mdev_parent_ops vfio_ccw_mdev_ops;
+extern const struct vfio_device_ops vfio_ccw_dev_ops;
+
+/**
+ * vfio_ccw_get_priv - get a referenced vfio_ccw_private from a subchannel
+ * @sch: subchannel whose drvdata holds the private, if an mdev exists
+ *
+ * Return: the private with a reference held on its vfio_device, or NULL
+ * when the drvdata is NULL or vfio_device_try_get() fails. The RCU read
+ * lock keeps the structure alive (it is released with kfree_rcu()) while
+ * the reference is taken. Callers drop it with vfio_device_put().
+ */
+static inline struct vfio_ccw_private *vfio_ccw_get_priv(struct subchannel *sch)
+{
+       struct vfio_ccw_private *private;
+
+       rcu_read_lock();
+       private = dev_get_drvdata(&sch->dev);
+       if (private && !vfio_device_try_get(&private->vdev))
+               private = NULL;
+       rcu_read_unlock();
+       return private;
+}
 
 /*
  * States of the device statemachine.
-- 
2.33.0

Reply via email to