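1. This commit hardens DMA unmap for indirect descriptors: the unmap path
   now takes the address, length and flags from a driver-private struct
   vring_desc_extra instead of reading them back from the indirect
   descriptor table.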
2. The subsequent commit uses struct vring_desc_extra to record whether
   the buffers need to be unmapped or not, so we need a struct
   vring_desc_extra for every descriptor, whether it is indirect or not.

Signed-off-by: Xuan Zhuo <xuanz...@linux.alibaba.com>
---
 drivers/virtio/virtio_ring.c | 122 ++++++++++++++++-------------------
 1 file changed, 57 insertions(+), 65 deletions(-)

diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index df8eb0521aa0..1e037da542b9 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -67,9 +67,16 @@
 #define LAST_ADD_TIME_INVALID(vq)
 #endif
 
+struct vring_desc_extra {
+       dma_addr_t addr;                /* Descriptor DMA addr. */
+       u32 len;                        /* Descriptor length. */
+       u16 flags;                      /* Descriptor flags. */
+       u16 next;                       /* The next desc state in a list. */
+};
+
 struct vring_desc_state_split {
        void *data;                     /* Data for callback. */
-       struct vring_desc *indir_desc;  /* Indirect descriptor, if any. */
+       struct vring_desc_extra *indir; /* Indirect descriptor, if any. */
 };
 
 struct vring_desc_state_packed {
@@ -79,13 +86,6 @@ struct vring_desc_state_packed {
        u16 last;                       /* The last desc state in a list. */
 };
 
-struct vring_desc_extra {
-       dma_addr_t addr;                /* Descriptor DMA addr. */
-       u32 len;                        /* Descriptor length. */
-       u16 flags;                      /* Descriptor flags. */
-       u16 next;                       /* The next desc state in a list. */
-};
-
 struct vring_virtqueue_split {
        /* Actual memory layout for this queue. */
        struct vring vring;
@@ -440,38 +440,20 @@ static void virtqueue_init(struct vring_virtqueue *vq, 
u32 num)
  * Split ring specific functions - *_split().
  */
 
-static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq,
-                                          const struct vring_desc *desc)
-{
-       u16 flags;
-
-       if (!vring_need_unmap_buffer(vq))
-               return;
-
-       flags = virtio16_to_cpu(vq->vq.vdev, desc->flags);
-
-       dma_unmap_page(vring_dma_dev(vq),
-                      virtio64_to_cpu(vq->vq.vdev, desc->addr),
-                      virtio32_to_cpu(vq->vq.vdev, desc->len),
-                      (flags & VRING_DESC_F_WRITE) ?
-                      DMA_FROM_DEVICE : DMA_TO_DEVICE);
-}
-
 static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
-                                         unsigned int i)
+                                         struct vring_desc_extra *extra)
 {
-       struct vring_desc_extra *extra = vq->split.desc_extra;
        u16 flags;
 
-       flags = extra[i].flags;
+       flags = extra->flags;
 
        if (flags & VRING_DESC_F_INDIRECT) {
                if (!vq->use_dma_api)
                        goto out;
 
                dma_unmap_single(vring_dma_dev(vq),
-                                extra[i].addr,
-                                extra[i].len,
+                                extra->addr,
+                                extra->len,
                                 (flags & VRING_DESC_F_WRITE) ?
                                 DMA_FROM_DEVICE : DMA_TO_DEVICE);
        } else {
@@ -479,20 +461,22 @@ static unsigned int vring_unmap_one_split(const struct 
vring_virtqueue *vq,
                        goto out;
 
                dma_unmap_page(vring_dma_dev(vq),
-                              extra[i].addr,
-                              extra[i].len,
+                              extra->addr,
+                              extra->len,
                               (flags & VRING_DESC_F_WRITE) ?
                               DMA_FROM_DEVICE : DMA_TO_DEVICE);
        }
 
 out:
-       return extra[i].next;
+       return extra->next;
 }
 
 static struct vring_desc *alloc_indirect_split(struct virtqueue *_vq,
                                               unsigned int total_sg,
+                                              struct vring_desc_extra **pextra,
                                               gfp_t gfp)
 {
+       struct vring_desc_extra *extra;
        struct vring_desc *desc;
        unsigned int i;
 
@@ -503,40 +487,45 @@ static struct vring_desc *alloc_indirect_split(struct 
virtqueue *_vq,
         */
        gfp &= ~__GFP_HIGHMEM;
 
-       desc = kmalloc_array(total_sg, sizeof(struct vring_desc), gfp);
-       if (!desc)
+       extra = kmalloc_array(total_sg, sizeof(*desc) + sizeof(*extra), gfp);
+       if (!extra)
                return NULL;
 
-       for (i = 0; i < total_sg; i++)
+       desc = (struct vring_desc *)&extra[total_sg];
+
+       for (i = 0; i < total_sg; i++) {
                desc[i].next = cpu_to_virtio16(_vq->vdev, i + 1);
+               extra[i].next = i + 1;
+       }
+
+       *pextra = extra;
+
        return desc;
 }
 
 static inline unsigned int virtqueue_add_desc_split(struct virtqueue *vq,
                                                    struct vring_desc *desc,
+                                                   struct vring_desc_extra 
*extra,
                                                    unsigned int i,
                                                    dma_addr_t addr,
                                                    unsigned int len,
                                                    u16 flags,
                                                    bool indirect)
 {
-       struct vring_virtqueue *vring = to_vvq(vq);
-       struct vring_desc_extra *extra = vring->split.desc_extra;
        u16 next;
 
        desc[i].flags = cpu_to_virtio16(vq->vdev, flags);
        desc[i].addr = cpu_to_virtio64(vq->vdev, addr);
        desc[i].len = cpu_to_virtio32(vq->vdev, len);
 
-       if (!indirect) {
-               next = extra[i].next;
-               desc[i].next = cpu_to_virtio16(vq->vdev, next);
+       extra[i].addr = addr;
+       extra[i].len = len;
+       extra[i].flags = flags;
 
-               extra[i].addr = addr;
-               extra[i].len = len;
-               extra[i].flags = flags;
-       } else
-               next = virtio16_to_cpu(vq->vdev, desc[i].next);
+       next = extra[i].next;
+
+       if (!indirect)
+               desc[i].next = cpu_to_virtio16(vq->vdev, next);
 
        return next;
 }
@@ -551,6 +540,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
                                      gfp_t gfp)
 {
        struct vring_virtqueue *vq = to_vvq(_vq);
+       struct vring_desc_extra *extra;
        struct scatterlist *sg;
        struct vring_desc *desc;
        unsigned int i, n, avail, descs_used, prev, err_idx;
@@ -574,7 +564,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
        head = vq->free_head;
 
        if (virtqueue_use_indirect(vq, total_sg))
-               desc = alloc_indirect_split(_vq, total_sg, gfp);
+               desc = alloc_indirect_split(_vq, total_sg, &extra, gfp);
        else {
                desc = NULL;
                WARN_ON_ONCE(total_sg > vq->split.vring.num && !vq->indirect);
@@ -589,6 +579,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
        } else {
                indirect = false;
                desc = vq->split.vring.desc;
+               extra = vq->split.desc_extra;
                i = head;
                descs_used = total_sg;
        }
@@ -618,7 +609,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
                        /* Note that we trust indirect descriptor
                         * table since it use stream DMA mapping.
                         */
-                       i = virtqueue_add_desc_split(_vq, desc, i, addr, 
sg->length,
+                       i = virtqueue_add_desc_split(_vq, desc, extra, i, addr, 
sg->length,
                                                     VRING_DESC_F_NEXT,
                                                     indirect);
                }
@@ -634,7 +625,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
                        /* Note that we trust indirect descriptor
                         * table since it use stream DMA mapping.
                         */
-                       i = virtqueue_add_desc_split(_vq, desc, i, addr,
+                       i = virtqueue_add_desc_split(_vq, desc, extra, i, addr,
                                                     sg->length,
                                                     VRING_DESC_F_NEXT |
                                                     VRING_DESC_F_WRITE,
@@ -660,6 +651,7 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
                }
 
                virtqueue_add_desc_split(_vq, vq->split.vring.desc,
+                                        vq->split.desc_extra,
                                         head, addr,
                                         total_sg * sizeof(struct vring_desc),
                                         VRING_DESC_F_INDIRECT,
@@ -678,9 +670,9 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
        /* Store token and indirect buffer state. */
        vq->split.desc_state[head].data = data;
        if (indirect)
-               vq->split.desc_state[head].indir_desc = desc;
+               vq->split.desc_state[head].indir = extra;
        else
-               vq->split.desc_state[head].indir_desc = ctx;
+               vq->split.desc_state[head].indir = ctx;
 
        /* Put entry in available array (but don't update avail->idx until they
         * do sync). */
@@ -716,11 +708,8 @@ static inline int virtqueue_add_split(struct virtqueue 
*_vq,
        for (n = 0; n < total_sg; n++) {
                if (i == err_idx)
                        break;
-               if (indirect) {
-                       vring_unmap_one_split_indirect(vq, &desc[i]);
-                       i = virtio16_to_cpu(_vq->vdev, desc[i].next);
-               } else
-                       i = vring_unmap_one_split(vq, i);
+
+               i = vring_unmap_one_split(vq, &extra[i]);
        }
 
 free_indirect:
@@ -765,22 +754,25 @@ static bool virtqueue_kick_prepare_split(struct virtqueue 
*_vq)
 static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
                             void **ctx)
 {
+       struct vring_desc_extra *extra;
        unsigned int i, j;
        __virtio16 nextflag = cpu_to_virtio16(vq->vq.vdev, VRING_DESC_F_NEXT);
 
        /* Clear data ptr. */
        vq->split.desc_state[head].data = NULL;
 
+       extra = vq->split.desc_extra;
+
        /* Put back on free list: unmap first-level descriptors and find end */
        i = head;
 
        while (vq->split.vring.desc[i].flags & nextflag) {
-               vring_unmap_one_split(vq, i);
+               vring_unmap_one_split(vq, &extra[i]);
                i = vq->split.desc_extra[i].next;
                vq->vq.num_free++;
        }
 
-       vring_unmap_one_split(vq, i);
+       vring_unmap_one_split(vq, &extra[i]);
        vq->split.desc_extra[i].next = vq->free_head;
        vq->free_head = head;
 
@@ -788,12 +780,12 @@ static void detach_buf_split(struct vring_virtqueue *vq, 
unsigned int head,
        vq->vq.num_free++;
 
        if (vq->indirect) {
-               struct vring_desc *indir_desc =
-                               vq->split.desc_state[head].indir_desc;
                u32 len;
 
+               extra = vq->split.desc_state[head].indir;
+
                /* Free the indirect table, if any, now that it's unmapped. */
-               if (!indir_desc)
+               if (!extra)
                        return;
 
                len = vq->split.desc_extra[head].len;
@@ -804,13 +796,13 @@ static void detach_buf_split(struct vring_virtqueue *vq, 
unsigned int head,
 
                if (vring_need_unmap_buffer(vq)) {
                        for (j = 0; j < len / sizeof(struct vring_desc); j++)
-                               vring_unmap_one_split_indirect(vq, 
&indir_desc[j]);
+                               vring_unmap_one_split(vq, &extra[j]);
                }
 
-               kfree(indir_desc);
-               vq->split.desc_state[head].indir_desc = NULL;
+               kfree(extra);
+               vq->split.desc_state[head].indir = NULL;
        } else if (ctx) {
-               *ctx = vq->split.desc_state[head].indir_desc;
+               *ctx = vq->split.desc_state[head].indir;
        }
 }
 
-- 
2.32.0.3.g01195cf9f


Reply via email to