The virtio spec requires the driver not to kick the device before
DRIVER_OK is set. init_vqs() primes the stats virtqueue with a buffer
and kicks the device before virtio_device_ready() is called in
virtballoon_probe(), violating this requirement.
Further, if the device responds to the early kick by processing the
buffer before DRIVER_OK, stats_request() fires and queues
update_balloon_stats_work. Should probe then fail and free vb, the work
runs against freed memory.
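To fix, move the stats buffer setup to after DRIVER_OK. Keep
update_balloon_stats_work disabled from init_vqs() until the buffer
has been queued, so that the work cannot race with the setup.
A failure to add the buffer is no longer fatal: the error is ignored
and the stats virtqueue is simply left without a buffer.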
Testing: tested that stats still work after the change.
Fixes: 9564e138b1f6 ("virtio: Add memory statistics reporting to the balloon driver (V4)")
Reported-by: Sashiko:gemini-3.1-pro-preview
Cc: Adam Litke <[email protected]>
Cc: David Hildenbrand <[email protected]>
Assisted-by: Claude:claude-sonnet-4-6
Signed-off-by: Michael S. Tsirkin <[email protected]>
---
Fixes a bug reported by Sashiko and pointed out to me by Andrew.
I guess I'll queue this myself.
drivers/virtio/virtio_balloon.c | 39 ++++++++++++++++++---------------
1 file changed, 21 insertions(+), 18 deletions(-)
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index 088b3a0e6ce6..d4cd2dd388e9 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -611,25 +611,8 @@ static int init_vqs(struct virtio_balloon *vb)
vb->inflate_vq = vqs[VIRTIO_BALLOON_VQ_INFLATE];
vb->deflate_vq = vqs[VIRTIO_BALLOON_VQ_DEFLATE];
if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ)) {
- struct scatterlist sg;
- unsigned int num_stats;
vb->stats_vq = vqs[VIRTIO_BALLOON_VQ_STATS];
-
- /*
- * Prime this virtqueue with one buffer so the hypervisor can
- * use it to signal us later (it can't be broken yet!).
- */
- num_stats = update_balloon_stats(vb);
-
- sg_init_one(&sg, vb->stats, sizeof(vb->stats[0]) * num_stats);
- err = virtqueue_add_outbuf(vb->stats_vq, &sg, 1, vb,
- GFP_KERNEL);
- if (err) {
- dev_warn(&vb->vdev->dev, "%s: add stat_vq failed\n",
- __func__);
- return err;
- }
- virtqueue_kick(vb->stats_vq);
+ disable_work(&vb->update_balloon_stats_work);
}
if (virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_FREE_PAGE_HINT))
@@ -916,6 +899,22 @@ static int virtio_balloon_register_shrinker(struct virtio_balloon *vb)
return 0;
}
+static void setup_vqs(struct virtio_balloon *vb)
+{
+ struct scatterlist sg;
+ unsigned int num_stats;
+
+ if (!virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_STATS_VQ))
+ return;
+
+ num_stats = update_balloon_stats(vb);
+ sg_init_one(&sg, vb->stats, sizeof(vb->stats[0]) * num_stats);
+ if (!virtqueue_add_outbuf(vb->stats_vq, &sg, 1, vb, GFP_KERNEL))
+ virtqueue_kick(vb->stats_vq);
+ enable_and_queue_work(system_freezable_wq,
+ &vb->update_balloon_stats_work);
+}
+
static int virtballoon_probe(struct virtio_device *vdev)
{
struct virtio_balloon *vb;
@@ -1059,6 +1058,8 @@ static int virtballoon_probe(struct virtio_device *vdev)
virtio_device_ready(vdev);
+ setup_vqs(vb);
+
if (towards_target(vb))
virtballoon_changed(vdev);
return 0;
@@ -1148,6 +1149,8 @@ static int virtballoon_restore(struct virtio_device *vdev)
virtio_device_ready(vdev);
+ setup_vqs(vb);
+
if (towards_target(vb))
virtballoon_changed(vdev);
update_balloon_size(vb);
--
MST