Currently vhost_get_vq_desc() returns the descriptor head to the device,
which then passes it back to vhost_add_used() and its friends. This
exposes the internal used ring layout to the device and makes it hard to
extend, e.g. for the packed ring layout.

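This patch hides the used ring layout by:

- letting vhost_get_vq_desc() fill in a caller-supplied struct
  vring_used_elem; it now returns 0 on success, -ENOSPC when no new
  descriptor is available (previously vq->num) and a negative errno on
  other errors
- letting vhost_get_bufs() return 0 or a negative errno and report the
  number of heads through a new @count argument
- accepting a pointer to struct vring_used_elem in vhost_add_used() and
  vhost_add_used_and_signal()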

Hiding the layout this way makes it easier to implement the packed ring
layout on top.

Signed-off-by: Jason Wang <jasow...@redhat.com>
---
 drivers/vhost/net.c   | 46 +++++++++++++++++++++-----------------
 drivers/vhost/scsi.c  | 62 +++++++++++++++++++++++++++------------------------
 drivers/vhost/vhost.c | 52 +++++++++++++++++++++---------------------
 drivers/vhost/vhost.h |  9 +++++---
 drivers/vhost/vsock.c | 42 +++++++++++++++++-----------------
 5 files changed, 112 insertions(+), 99 deletions(-)

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 762aa81..826489c 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -426,22 +426,24 @@ static int vhost_net_enable_vq(struct vhost_net *n,
 
 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
                                    struct vhost_virtqueue *vq,
+                                   struct vring_used_elem *used_elem,
                                    struct iovec iov[], unsigned int iov_size,
                                    unsigned int *out_num, unsigned int *in_num)
 {
        unsigned long uninitialized_var(endtime);
-       int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
+       int r = vhost_get_vq_desc(vq, used_elem, vq->iov, ARRAY_SIZE(vq->iov),
                                  out_num, in_num, NULL, NULL);
 
-       if (r == vq->num && vq->busyloop_timeout) {
+       if (r == -ENOSPC && vq->busyloop_timeout) {
                preempt_disable();
                endtime = busy_clock() + vq->busyloop_timeout;
                while (vhost_can_busy_poll(vq->dev, endtime) &&
                       vhost_vq_avail_empty(vq->dev, vq))
                        cpu_relax();
                preempt_enable();
-               r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-                                     out_num, in_num, NULL, NULL);
+               r = vhost_get_vq_desc(vq, used_elem, vq->iov,
+                                     ARRAY_SIZE(vq->iov), out_num, in_num,
+                                     NULL, NULL);
        }
 
        return r;
@@ -463,7 +465,6 @@ static void handle_tx(struct vhost_net *net)
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        struct vhost_virtqueue *vq = &nvq->vq;
        unsigned out, in;
-       int head;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
@@ -476,6 +477,7 @@ static void handle_tx(struct vhost_net *net)
        size_t hdr_size;
        struct socket *sock;
        struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
+       struct vring_used_elem used;
        bool zcopy, zcopy_used;
        int sent_pkts = 0;

@@ -499,20 +501,20 @@ static void handle_tx(struct vhost_net *net)
                        vhost_zerocopy_signal_used(net, vq);


-               head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
-                                               ARRAY_SIZE(vq->iov),
-                                               &out, &in);
-               /* On error, stop handling until the next kick. */
-               if (unlikely(head < 0))
-                       break;
+               err = vhost_net_tx_get_vq_desc(net, vq, &used, vq->iov,
+                                              ARRAY_SIZE(vq->iov),
+                                              &out, &in);
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
-               if (head == vq->num) {
+               if (err == -ENOSPC) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                vhost_disable_notify(&net->dev, vq);
                                continue;
                        }
                        break;
                }
+               /* On error, stop handling until the next kick. */
+               if (unlikely(err < 0))
+                       break;
                if (in) {
                        vq_err(vq, "Unexpected descriptor format for TX: "
                               "out %d, int %d\n", out, in);
@@ -540,7 +542,8 @@ static void handle_tx(struct vhost_net *net)
                        struct ubuf_info *ubuf;
                        ubuf = nvq->ubuf_info + nvq->upend_idx;

-                       vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
+                       /* used.id is already in guest byte order. */
+                       vq->heads[nvq->upend_idx].id = used.id;
                        vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
                        ubuf->callback = vhost_zerocopy_callback;
                        ubuf->ctx = nvq->ubufs;
@@ -581,7 +584,7 @@ static void handle_tx(struct vhost_net *net)
                        pr_debug("Truncated TX packet: "
                                 " len %d != %zd\n", err, len);
                if (!zcopy_used)
-                       vhost_add_used_and_signal(&net->dev, vq, head, 0);
+                       vhost_add_used_and_signal(&net->dev, vq, &used, 0);
                else
                        vhost_zerocopy_signal_used(net, vq);
                vhost_net_tx_packet(net);
@@ -713,14 +716,12 @@ static void handle_rx(struct vhost_net *net)
        while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
-               headcount = vhost_get_bufs(vq, vq->heads + nheads, vhost_len,
-                                          &in, vq_log, &log,
-                                          likely(mergeable) ? UIO_MAXIOV : 1);
-               /* On error, stop handling until the next kick. */
-               if (unlikely(headcount < 0))
-                       goto out;
+               err = vhost_get_bufs(vq, vq->heads + nheads, vhost_len,
+                                    &in, vq_log, &log,
+                                    likely(mergeable) ? UIO_MAXIOV : 1,
+                                    &headcount);
                /* OK, now we need to know about added descriptors. */
-               if (!headcount) {
+               if (err == -ENOSPC) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
@@ -731,6 +732,9 @@ static void handle_rx(struct vhost_net *net)
                         * they refilled. */
                        goto out;
                }
+               /* On error, stop handling until the next kick. */
+               if (unlikely(err < 0))
+                       goto out;
                if (nvq->rx_ring)
                        msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
                /* On overrun, truncate and discard */
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 7ad5709..654c71f 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -67,7 +67,7 @@ struct vhost_scsi_inflight {

 struct vhost_scsi_cmd {
-       /* Descriptor from vhost_get_vq_desc() for virt_queue segment */
-       int tvc_vq_desc;
+       /* Used element filled in by vhost_get_vq_desc() for this request */
+       struct vring_used_elem tvc_vq_used;
        /* virtio-scsi initiator task attribute */
        int tvc_task_attr;
        /* virtio-scsi response incoming iovecs */
@@ -441,8 +441,9 @@ vhost_scsi_do_evt_work(struct vhost_scsi *vs, struct vhost_scsi_evt *evt)
        struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
        struct virtio_scsi_event *event = &evt->event;
        struct virtio_scsi_event __user *eventp;
+       struct vring_used_elem used;
        unsigned out, in;
-       int head, ret;
+       int ret;

        if (!vq->private_data) {
                vs->vs_events_missed = true;
@@ -451,16 +452,16 @@ vhost_scsi_do_evt_work(struct vhost_scsi *vs, struct vhost_scsi_evt *evt)

 again:
        vhost_disable_notify(&vs->dev, vq);
-       head = vhost_get_vq_desc(vq, vq->iov,
+       ret = vhost_get_vq_desc(vq, &used, vq->iov,
                        ARRAY_SIZE(vq->iov), &out, &in,
                        NULL, NULL);
-       if (head < 0) {
+       if (ret == -ENOSPC) {
+               if (vhost_enable_notify(&vs->dev, vq))
+                       goto again;
                vs->vs_events_missed = true;
                return;
        }
-       if (head == vq->num) {
-               if (vhost_enable_notify(&vs->dev, vq))
-                       goto again;
+       if (ret < 0) {
                vs->vs_events_missed = true;
                return;
        }
@@ -480,7 +481,7 @@ vhost_scsi_do_evt_work(struct vhost_scsi *vs, struct vhost_scsi_evt *evt)
        eventp = vq->iov[out].iov_base;
        ret = __copy_to_user(eventp, event, sizeof(*event));
        if (!ret)
-               vhost_add_used_and_signal(&vs->dev, vq, head, 0);
+               vhost_add_used_and_signal(&vs->dev, vq, &used, 0);
        else
                vq_err(vq, "Faulted on vhost_scsi_send_event\n");
 }
@@ -541,7 +542,7 @@ static void vhost_scsi_complete_cmd_work(struct vhost_work *work)
                ret = copy_to_iter(&v_rsp, sizeof(v_rsp), &iov_iter);
                if (likely(ret == sizeof(v_rsp))) {
                        struct vhost_scsi_virtqueue *q;
-                       vhost_add_used(cmd->tvc_vq, cmd->tvc_vq_desc, 0);
+                       vhost_add_used(cmd->tvc_vq, &cmd->tvc_vq_used, 0);
                        q = container_of(cmd->tvc_vq, struct vhost_scsi_virtqueue, vq);
                        vq = q - vs->vqs;
                        __set_bit(vq, signal);
@@ -784,7 +785,7 @@ static void vhost_scsi_submission_work(struct work_struct *work)
 static void
 vhost_scsi_send_bad_target(struct vhost_scsi *vs,
                           struct vhost_virtqueue *vq,
-                          int head, unsigned out)
+                          struct vring_used_elem *used, unsigned out)
 {
        struct virtio_scsi_cmd_resp __user *resp;
        struct virtio_scsi_cmd_resp rsp;
@@ -795,7 +796,7 @@ vhost_scsi_send_bad_target(struct vhost_scsi *vs,
        resp = vq->iov[out].iov_base;
        ret = __copy_to_user(resp, &rsp, sizeof(rsp));
        if (!ret)
-               vhost_add_used_and_signal(&vs->dev, vq, head, 0);
+               vhost_add_used_and_signal(&vs->dev, vq, used, 0);
        else
                pr_err("Faulted on virtio_scsi_cmd_resp\n");
 }
@@ -807,11 +808,12 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
        struct virtio_scsi_cmd_req v_req;
        struct virtio_scsi_cmd_req_pi v_req_pi;
        struct vhost_scsi_cmd *cmd;
+       struct vring_used_elem used;
        struct iov_iter out_iter, in_iter, prot_iter, data_iter;
        u64 tag;
        u32 exp_data_len, data_direction;
        unsigned int out = 0, in = 0;
-       int head, ret, prot_bytes;
+       int ret, prot_bytes;
        size_t req_size, rsp_size = sizeof(struct virtio_scsi_cmd_resp);
        size_t out_size, in_size;
        u16 lun;
@@ -831,22 +833,22 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
        vhost_disable_notify(&vs->dev, vq);

        for (;;) {
-               head = vhost_get_vq_desc(vq, vq->iov,
-                                        ARRAY_SIZE(vq->iov), &out, &in,
-                                        NULL, NULL);
-               pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
-                        head, out, in);
-               /* On error, stop handling until the next kick. */
-               if (unlikely(head < 0))
-                       break;
+               ret = vhost_get_vq_desc(vq, &used, vq->iov,
+                                       ARRAY_SIZE(vq->iov), &out, &in,
+                                       NULL, NULL);
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
-               if (head == vq->num) {
+               if (ret == -ENOSPC) {
                        if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
                                vhost_disable_notify(&vs->dev, vq);
                                continue;
                        }
                        break;
                }
+               /* On error, stop handling until the next kick. */
+               if (unlikely(ret < 0))
+                       break;
+               pr_debug("vhost_get_vq_desc: head: %u, out: %u in: %u\n",
+                        vhost32_to_cpu(vq, used.id), out, in);
                /*
                 * Check for a sane response buffer so we can report early
                 * errors back to the guest.
@@ -891,20 +893,20 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)

                if (unlikely(!copy_from_iter_full(req, req_size, &out_iter))) {
                        vq_err(vq, "Faulted on copy_from_iter\n");
-                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                       vhost_scsi_send_bad_target(vs, vq, &used, out);
                        continue;
                }
                /* virtio-scsi spec requires byte 0 of the lun to be 1 */
                if (unlikely(*lunp != 1)) {
                        vq_err(vq, "Illegal virtio-scsi lun: %u\n", *lunp);
-                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                       vhost_scsi_send_bad_target(vs, vq, &used, out);
                        continue;
                }

                tpg = READ_ONCE(vs_tpg[*target]);
                if (unlikely(!tpg)) {
                        /* Target does not exist, fail the request */
-                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                       vhost_scsi_send_bad_target(vs, vq, &used, out);
                        continue;
                }
                /*
@@ -950,7 +952,8 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                                if (data_direction != DMA_TO_DEVICE) {
                                        vq_err(vq, "Received non zero pi_bytesout,"
                                                " but wrong data_direction\n");
-                                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                                       vhost_scsi_send_bad_target(vs, vq,
+                                                                  &used, out);
                                        continue;
                                }
                                prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesout);
@@ -958,7 +961,8 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                                if (data_direction != DMA_FROM_DEVICE) {
                                        vq_err(vq, "Received non zero pi_bytesin,"
                                                " but wrong data_direction\n");
-                                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                                       vhost_scsi_send_bad_target(vs, vq,
+                                                                  &used, out);
                                        continue;
                                }
                                prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesin);
@@ -996,7 +1000,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                        vq_err(vq, "Received SCSI CDB with command_size: %d that"
                                " exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
                                scsi_command_size(cdb), VHOST_SCSI_MAX_CDB_SIZE);
-                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                       vhost_scsi_send_bad_target(vs, vq, &used, out);
                        continue;
                }
                cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr,
@@ -1005,7 +1009,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                if (IS_ERR(cmd)) {
                        vq_err(vq, "vhost_scsi_get_tag failed %ld\n",
                               PTR_ERR(cmd));
-                       vhost_scsi_send_bad_target(vs, vq, head, out);
+                       vhost_scsi_send_bad_target(vs, vq, &used, out);
                        continue;
                }
                cmd->tvc_vhost = vs;
@@ -1025,7 +1029,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                        if (unlikely(ret)) {
                                vq_err(vq, "Failed to map iov to sgl\n");
                                vhost_scsi_release_cmd(&cmd->tvc_se_cmd);
-                               vhost_scsi_send_bad_target(vs, vq, head, out);
+                               vhost_scsi_send_bad_target(vs, vq, &used, out);
                                continue;
                        }
                }
@@ -1034,7 +1038,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
                 * complete the virtio-scsi request in TCM callback context via
                 * vhost_scsi_queue_data_in() and vhost_scsi_queue_status()
                 */
-               cmd->tvc_vq_desc = head;
+               cmd->tvc_vq_used = used;
                /*
                 * Dispatch cmd descriptor for cmwq execution in process
                 * context provided by vhost_scsi_workqueue.  This also ensures
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 6b455f6..e069adc 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -1955,6 +1955,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
  * never a valid descriptor number) if none was found.  A negative code is
  * returned on error. */
 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
+                     struct vring_used_elem *used,
                      struct iovec iov[], unsigned int iov_size,
                      unsigned int *out_num, unsigned int *in_num,
                      struct vhost_log *log, unsigned int *log_num)
@@ -1987,7 +1988,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
                 * invalid.
                 */
                if (vq->avail_idx == last_avail_idx)
-                       return vq->num;
+                       return -ENOSPC;

                /* Only get avail ring entries after they have been
                 * exposed by guest.
@@ -2093,10 +2094,17 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
        /* Assume notifications from guest are disabled at this point,
         * if they aren't we would need to update avail_event index. */
        BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
-       return head;
+       used->id = cpu_to_vhost32(vq, head);
+       return 0;
 }
 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);

+static void vhost_set_used_len(struct vhost_virtqueue *vq,
+                              struct vring_used_elem *used, int len)
+{
+       used->len = cpu_to_vhost32(vq, len);
+}
+
 /* This is a multi-buffer version of vhost_get_desc, that works if
  *     vq has read descriptors only.
  * @vq         - the relevant virtqueue
@@ -2113,13 +2121,13 @@ int vhost_get_bufs(struct vhost_virtqueue *vq,
                   unsigned *iovcount,
                   struct vhost_log *log,
                   unsigned *log_num,
-                  unsigned int quota)
+                  unsigned int quota,
+                  s16 *count)
 {
        unsigned int out, in;
        int seg = 0;
        int headcount = 0;
-       unsigned d;
-       int r, nlogs = 0;
+       int r = 0, nlogs = 0;
        /* len is always initialized before use since we are always called with
         * datalen > 0.
         */
@@ -2130,17 +2138,12 @@ int vhost_get_bufs(struct vhost_virtqueue *vq,
                        r = -ENOBUFS;
                        goto err;
                }
-               r = vhost_get_vq_desc(vq, vq->iov + seg,
+               r = vhost_get_vq_desc(vq, &heads[headcount], vq->iov + seg,
                                      ARRAY_SIZE(vq->iov) - seg, &out,
                                      &in, log, log_num);
                if (unlikely(r < 0))
                        goto err;

-               d = r;
-               if (d == vq->num) {
-                       r = 0;
-                       goto err;
-               }
                if (unlikely(out || in <= 0)) {
                        vq_err(vq, "unexpected descriptor format for RX: "
                                "out %d, in %d\n", out, in);
@@ -2151,24 +2154,26 @@ int vhost_get_bufs(struct vhost_virtqueue *vq,
                        nlogs += *log_num;
                        log += *log_num;
                }
-               heads[headcount].id = cpu_to_vhost32(vq, d);
                len = iov_length(vq->iov + seg, in);
-               heads[headcount].len = cpu_to_vhost32(vq, len);
+               vhost_set_used_len(vq, &heads[headcount], len);
                datalen -= len;
                ++headcount;
                seg += in;
        }
-       heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
+       vhost_set_used_len(vq, &heads[headcount - 1], len + datalen);
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;

        /* Detect overrun */
        if (unlikely(datalen > 0)) {
-               r = UIO_MAXIOV + 1;
-               goto err;
+               /* Hand back what we took; caller truncates and discards. */
+               vhost_discard_vq_desc(vq, headcount);
+               headcount = UIO_MAXIOV + 1;
        }
-       return headcount;
+
+       *count = headcount;
+       return 0;
 err:
        vhost_discard_vq_desc(vq, headcount);
        return r;
@@ -2184,14 +2189,12 @@ EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);

 /* After we've used one of their buffers, we tell them about it.  We'll then
- * want to notify the guest, using eventfd. */
-int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
+ * want to notify the guest, using eventfd.  @used should come from
+ * vhost_get_vq_desc(); its len field is overwritten with @len. */
+int vhost_add_used(struct vhost_virtqueue *vq, struct vring_used_elem *used,
+                  int len)
 {
-       struct vring_used_elem heads = {
-               cpu_to_vhost32(vq, head),
-               cpu_to_vhost32(vq, len)
-       };
-
-       return vhost_add_used_n(vq, &heads, 1);
+       vhost_set_used_len(vq, used, len);
+       return vhost_add_used_n(vq, used, 1);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used);

@@ -2324,9 +2327,9 @@ EXPORT_SYMBOL_GPL(vhost_signal);
 /* And here's the combo meal deal.  Supersize me! */
 void vhost_add_used_and_signal(struct vhost_dev *dev,
                               struct vhost_virtqueue *vq,
-                              unsigned int head, int len)
+                              struct vring_used_elem *used, int len)
 {
-       vhost_add_used(vq, head, len);
+       vhost_add_used(vq, used, len);
        vhost_signal(dev, vq);
 }
 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 52edd242..a7cc7e7 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -182,6 +182,7 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq);
 bool vhost_log_access_ok(struct vhost_dev *);

 int vhost_get_vq_desc(struct vhost_virtqueue *,
+                     struct vring_used_elem *used,
                      struct iovec iov[], unsigned int iov_count,
                      unsigned int *out_num, unsigned int *in_num,
                      struct vhost_log *log, unsigned int *log_num);
@@ -191,15 +192,17 @@ int vhost_get_bufs(struct vhost_virtqueue *vq,
                   unsigned *iovcount,
                   struct vhost_log *log,
                   unsigned *log_num,
-                  unsigned int quota);
+                  unsigned int quota,
+                  s16 *count);
 void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);

 int vhost_vq_init_access(struct vhost_virtqueue *);
-int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
+int vhost_add_used(struct vhost_virtqueue *vq,
+                  struct vring_used_elem *used, int len);
 int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
                     unsigned count);
 void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
-                              unsigned int id, int len);
+                              struct vring_used_elem *used, int len);
 void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
                               struct vring_used_elem *heads, unsigned count);
 void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 34bc3ab..59a01cd 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -98,11 +98,12 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,

        for (;;) {
                struct virtio_vsock_pkt *pkt;
+               struct vring_used_elem used;
                struct iov_iter iov_iter;
                unsigned out, in;
                size_t nbytes;
                size_t len;
-               int head;
+               int ret;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
@@ -116,16 +117,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

-               head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-                                        &out, &in, NULL, NULL);
-               if (head < 0) {
-                       spin_lock_bh(&vsock->send_pkt_list_lock);
-                       list_add(&pkt->list, &vsock->send_pkt_list);
-                       spin_unlock_bh(&vsock->send_pkt_list_lock);
-                       break;
-               }
-
-               if (head == vq->num) {
+               ret = vhost_get_vq_desc(vq, &used, vq->iov, ARRAY_SIZE(vq->iov),
+                                       &out, &in, NULL, NULL);
+               if (ret == -ENOSPC) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
@@ -139,6 +133,12 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                        }
                        break;
                }
+               if (ret < 0) {
+                       spin_lock_bh(&vsock->send_pkt_list_lock);
+                       list_add(&pkt->list, &vsock->send_pkt_list);
+                       spin_unlock_bh(&vsock->send_pkt_list_lock);
+                       break;
+               }

                if (out) {
                        virtio_transport_free_pkt(pkt);
@@ -163,7 +163,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                        break;
                }

-               vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
+               vhost_add_used(vq, &used, sizeof(pkt->hdr) + pkt->len);
                added = true;

                if (pkt->reply) {
@@ -346,7 +346,8 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
        struct virtio_vsock_pkt *pkt;
-       int head;
+       struct vring_used_elem used;
+       int ret;
        unsigned int out, in;
        bool added = false;

@@ -367,18 +368,17 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
                        goto no_more_replies;
                }

-               head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
-                                        &out, &in, NULL, NULL);
-               if (head < 0)
-                       break;
-
-               if (head == vq->num) {
+               ret = vhost_get_vq_desc(vq, &used, vq->iov, ARRAY_SIZE(vq->iov),
+                                       &out, &in, NULL, NULL);
+               if (ret == -ENOSPC) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }
+               if (ret < 0)
+                       break;

                pkt = vhost_vsock_alloc_pkt(vq, out, in);
                if (!pkt) {
@@ -397,7 +397,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
                else
                        virtio_transport_free_pkt(pkt);

-               vhost_add_used(vq, head, sizeof(pkt->hdr) + len);
+               vhost_add_used(vq, &used, sizeof(pkt->hdr) + len);
                added = true;
        }

-- 
2.7.4

Reply via email to