The virtio spec requires the driver not to kick the device before
DRIVER_OK is set.  init_vqs() primes the stats virtqueue with a buffer
and kicks the device before virtio_device_ready() is called in
virtballoon_probe(), violating this requirement.

Further, if the device responds to the early kick by processing the
buffer before DRIVER_OK, stats_request() fires and queues
update_balloon_stats_work.  Should probe then fail and free vb, the work
runs against freed memory.

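To fix, move the buffer setup into a new helper, setup_vqs(), which is
called after virtio_device_ready() in both probe and restore.
init_vqs() disables update_balloon_stats_work so that it cannot race
with the setup. setup_vqs() re-enables the work and queues it once, so
a stats request that the device completes in the meantime is not lost.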

If virtqueue_add_outbuf() fails, setup_vqs() only emits a warning; it
does not fail probe or restore. In practice the call cannot fail in
these contexts, because the queue has just been initialized and is empty.

Testing: tested that stats still work after the change.

Fixes: 9564e138b1f6 ("virtio: Add memory statistics reporting to the balloon driver (V4)")
Reported-by: Sashiko:gemini-3.1-pro-preview
Cc: David Hildenbrand <[email protected]>
Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Michael S. Tsirkin <[email protected]>
---

Changes since v1:
        check that work enable/disable is balanced
        explain why adding the buffer cannot fail in probe/restore

 drivers/virtio/virtio_balloon.c | 50 +++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 18 deletions(-)

diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index 088b3a0e6ce6..bc0a2b19ca7d 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -611,25 +611,8 @@ static int init_vqs(struct virtio_balloon *vb)
        vb->inflate_vq = vqs[VIRTIO_BALLOON_VQ_INFLATE];
        vb->deflate_vq = vqs[VIRTIO_BALLOON_VQ_DEFLATE];
        if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ)) {
-               struct scatterlist sg;
-               unsigned int num_stats;
                vb->stats_vq = vqs[VIRTIO_BALLOON_VQ_STATS];
-
-               /*
-                * Prime this virtqueue with one buffer so the hypervisor can
-                * use it to signal us later (it can't be broken yet!).
-                */
-               num_stats = update_balloon_stats(vb);
-
-               sg_init_one(&sg, vb->stats, sizeof(vb->stats[0]) * num_stats);
-               err = virtqueue_add_outbuf(vb->stats_vq, &sg, 1, vb,
-                                          GFP_KERNEL);
-               if (err) {
-                       dev_warn(&vb->vdev->dev, "%s: add stat_vq failed\n",
-                                __func__);
-                       return err;
-               }
-               virtqueue_kick(vb->stats_vq);
+               disable_work(&vb->update_balloon_stats_work);
        }
 
        if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
@@ -916,6 +899,33 @@ static int virtio_balloon_register_shrinker(struct virtio_balloon *vb)
        return 0;
 }
 
+/*
+ * Prime stats_vq after DRIVER_OK; pairs with disable_work() in init_vqs().
+ */
+static void setup_vqs(struct virtio_balloon *vb)
+{
+       struct scatterlist sg;
+       unsigned int num_stats;
+
+       if (!virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ))
+               return;
+
+       /* Give the hypervisor one buffer it can use to request stats. */
+       num_stats = update_balloon_stats(vb);
+       sg_init_one(&sg, vb->stats, sizeof(vb->stats[0]) * num_stats);
+       if (virtqueue_add_outbuf(vb->stats_vq, &sg, 1, vb, GFP_KERNEL)) {
+               dev_warn(&vb->vdev->dev, "%s: add stat_vq failed\n", __func__);
+               /* No buffer to poll, but keep disable/enable balanced. */
+               WARN_ON(!enable_work(&vb->update_balloon_stats_work));
+               return;
+       }
+       virtqueue_kick(vb->stats_vq);
+
+       /* Queue once: a request may have completed while work was disabled. */
+       WARN_ON(!enable_and_queue_work(system_freezable_wq,
+                                      &vb->update_balloon_stats_work));
+}
+
 static int virtballoon_probe(struct virtio_device *vdev)
 {
        struct virtio_balloon *vb;
@@ -1059,6 +1069,8 @@ static int virtballoon_probe(struct virtio_device *vdev)
 
        virtio_device_ready(vdev);
 
+       setup_vqs(vb);
+
        if (towards_target(vb))
                virtballoon_changed(vdev);
        return 0;
@@ -1148,6 +1160,8 @@ static int virtballoon_restore(struct virtio_device *vdev)
 
        virtio_device_ready(vdev);
 
+       setup_vqs(vb);
+
        if (towards_target(vb))
                virtballoon_changed(vdev);
        update_balloon_size(vb);
-- 
MST


Reply via email to