On Tue, Dec 5, 2023 at 11:47 AM Dragos Tatulea <dtatu...@nvidia.com> wrote:
>
> Deleting the old mr during mr update (.set_map) and then modifying the
> vqs with the new mr is not a good flow for firmware. The firmware
> expects that mkeys are deleted after there are no more vqs referencing
> them.
>
> Introduce reference counting for mrs to fix this. It is the only way to
> make sure that mkeys are not in use by vqs.
>
> An mr reference is taken when the mr is associated with the mr asid table
> and when the mr is linked to the vq on create/modify. The reference is
> released when the mr is unlinked from the vq (through modify/destroy)
> and from the mr asid table.
>
> To make things consistent, get rid of mlx5_vdpa_destroy_mr and use
> get/put semantics everywhere.
>
> Signed-off-by: Dragos Tatulea <dtatu...@nvidia.com>
> Reviewed-by: Gal Pressman <g...@nvidia.com>

Acked-by: Eugenio Pérez <epere...@redhat.com>

> ---
>  drivers/vdpa/mlx5/core/mlx5_vdpa.h |  8 +++--
>  drivers/vdpa/mlx5/core/mr.c        | 50 ++++++++++++++++++++----------
>  drivers/vdpa/mlx5/net/mlx5_vnet.c  | 45 ++++++++++++++++++++++-----
>  3 files changed, 78 insertions(+), 25 deletions(-)
>
> diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> index 84547d998bcf..1a0d27b6e09a 100644
> --- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> +++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> @@ -35,6 +35,8 @@ struct mlx5_vdpa_mr {
>         struct vhost_iotlb *iotlb;
>
>         bool user_mr;
> +
> +       refcount_t refcount;
>  };
>
>  struct mlx5_vdpa_resources {
> @@ -118,8 +120,10 @@ int mlx5_vdpa_destroy_mkey(struct mlx5_vdpa_dev *mvdev, u32 mkey);
>  struct mlx5_vdpa_mr *mlx5_vdpa_create_mr(struct mlx5_vdpa_dev *mvdev,
>                                          struct vhost_iotlb *iotlb);
>  void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev);
> -void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
> -                         struct mlx5_vdpa_mr *mr);
> +void mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> +                     struct mlx5_vdpa_mr *mr);
> +void mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> +                     struct mlx5_vdpa_mr *mr);
>  void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
>                          struct mlx5_vdpa_mr *mr,
>                          unsigned int asid);
> diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
> index 2197c46e563a..c7dc8914354a 100644
> --- a/drivers/vdpa/mlx5/core/mr.c
> +++ b/drivers/vdpa/mlx5/core/mr.c
> @@ -498,32 +498,52 @@ static void destroy_user_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr
>
>  static void _mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr)
>  {
> +       if (WARN_ON(!mr))
> +               return;
> +
>         if (mr->user_mr)
>                 destroy_user_mr(mvdev, mr);
>         else
>                 destroy_dma_mr(mvdev, mr);
>
>         vhost_iotlb_free(mr->iotlb);
> +
> +       kfree(mr);
>  }
>
> -void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
> -                         struct mlx5_vdpa_mr *mr)
> +static void _mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> +                             struct mlx5_vdpa_mr *mr)
>  {
>         if (!mr)
>                 return;
>
> +       if (refcount_dec_and_test(&mr->refcount))
> +               _mlx5_vdpa_destroy_mr(mvdev, mr);
> +}
> +
> +void mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> +                     struct mlx5_vdpa_mr *mr)
> +{
>         mutex_lock(&mvdev->mr_mtx);
> +       _mlx5_vdpa_put_mr(mvdev, mr);
> +       mutex_unlock(&mvdev->mr_mtx);
> +}
>
> -       _mlx5_vdpa_destroy_mr(mvdev, mr);
> +static void _mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> +                             struct mlx5_vdpa_mr *mr)
> +{
> +       if (!mr)
> +               return;
>
> -       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
> -               if (mvdev->mr[i] == mr)
> -                       mvdev->mr[i] = NULL;
> -       }
> +       refcount_inc(&mr->refcount);
> +}
>
> +void mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> +                     struct mlx5_vdpa_mr *mr)
> +{
> +       mutex_lock(&mvdev->mr_mtx);
> +       _mlx5_vdpa_get_mr(mvdev, mr);
>         mutex_unlock(&mvdev->mr_mtx);
> -
> -       kfree(mr);
>  }
>
>  void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
> @@ -534,20 +554,16 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
>
>         mutex_lock(&mvdev->mr_mtx);
>
> +       _mlx5_vdpa_put_mr(mvdev, old_mr);
>         mvdev->mr[asid] = new_mr;
> -       if (old_mr) {
> -               _mlx5_vdpa_destroy_mr(mvdev, old_mr);
> -               kfree(old_mr);
> -       }
>
>         mutex_unlock(&mvdev->mr_mtx);
> -
>  }
>
>  void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
>  {
>         for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
> -               mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
> +               mlx5_vdpa_update_mr(mvdev, NULL, i);
>
>         prune_iotlb(mvdev->cvq.iotlb);
>  }
> @@ -607,6 +623,8 @@ struct mlx5_vdpa_mr *mlx5_vdpa_create_mr(struct mlx5_vdpa_dev *mvdev,
>         if (err)
>                 goto out_err;
>
> +       refcount_set(&mr->refcount, 1);
> +
>         return mr;
>
>  out_err:
> @@ -651,7 +669,7 @@ int mlx5_vdpa_reset_mr(struct mlx5_vdpa_dev *mvdev, unsigned int asid)
>         if (asid >= MLX5_VDPA_NUM_AS)
>                 return -EINVAL;
>
> -       mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[asid]);
> +       mlx5_vdpa_update_mr(mvdev, NULL, asid);
>
>         if (asid == 0 && MLX5_CAP_GEN(mvdev->mdev, umem_uid_0)) {
>                 if (mlx5_vdpa_create_dma_mr(mvdev))
> diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> index 6a21223d97a8..133cbb66dcfe 100644
> --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
> +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> @@ -123,6 +123,9 @@ struct mlx5_vdpa_virtqueue {
>
>         u64 modified_fields;
>
> +       struct mlx5_vdpa_mr *vq_mr;
> +       struct mlx5_vdpa_mr *desc_mr;
> +
>         struct msi_map map;
>
>         /* keep last in the struct */
> @@ -946,6 +949,14 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
>         kfree(in);
>         mvq->virtq_id = MLX5_GET(general_obj_out_cmd_hdr, out, obj_id);
>
> +       mlx5_vdpa_get_mr(mvdev, vq_mr);
> +       mvq->vq_mr = vq_mr;
> +
> +       if (vq_desc_mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported)) {
> +               mlx5_vdpa_get_mr(mvdev, vq_desc_mr);
> +               mvq->desc_mr = vq_desc_mr;
> +       }
> +
>         return 0;
>
>  err_cmd:
> @@ -972,6 +983,12 @@ static void destroy_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtq
>         }
>         mvq->fw_state = MLX5_VIRTIO_NET_Q_OBJECT_NONE;
>         umems_destroy(ndev, mvq);
> +
> +       mlx5_vdpa_put_mr(&ndev->mvdev, mvq->vq_mr);
> +       mvq->vq_mr = NULL;
> +
> +       mlx5_vdpa_put_mr(&ndev->mvdev, mvq->desc_mr);
> +       mvq->desc_mr = NULL;
>  }
>
>  static u32 get_rqpn(struct mlx5_vdpa_virtqueue *mvq, bool fw)
> @@ -1207,6 +1224,8 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
>         int inlen = MLX5_ST_SZ_BYTES(modify_virtio_net_q_in);
>         u32 out[MLX5_ST_SZ_DW(modify_virtio_net_q_out)] = {};
>         struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
> +       struct mlx5_vdpa_mr *desc_mr = NULL;
> +       struct mlx5_vdpa_mr *vq_mr = NULL;
>         bool state_change = false;
>         void *obj_context;
>         void *cmd_hdr;
> @@ -1257,19 +1276,19 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
>                 MLX5_SET(virtio_net_q_object, obj_context, hw_used_index, mvq->used_idx);
>
>         if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY) {
> -               struct mlx5_vdpa_mr *mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
> +               vq_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
>
> -               if (mr)
> -                       MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, mr->mkey);
> +               if (vq_mr)
> +                       MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, vq_mr->mkey);
>                 else
>                         mvq->modified_fields &= ~MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY;
>         }
>
>         if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY) {
> -               struct mlx5_vdpa_mr *mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_DESC_GROUP]];
> +               desc_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_DESC_GROUP]];
>
> -               if (mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported))
> -                       MLX5_SET(virtio_q, vq_ctx, desc_group_mkey, mr->mkey);
> +               if (desc_mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported))
> +                       MLX5_SET(virtio_q, vq_ctx, desc_group_mkey, desc_mr->mkey);
>                 else
>                         mvq->modified_fields &= ~MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY;
>         }
> @@ -1282,6 +1301,18 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
>         if (state_change)
>                 mvq->fw_state = state;
>
> +       if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY) {
> +               mlx5_vdpa_put_mr(mvdev, mvq->vq_mr);
> +               mlx5_vdpa_get_mr(mvdev, vq_mr);
> +               mvq->vq_mr = vq_mr;
> +       }
> +
> +       if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY) {
> +               mlx5_vdpa_put_mr(mvdev, mvq->desc_mr);
> +               mlx5_vdpa_get_mr(mvdev, desc_mr);
> +               mvq->desc_mr = desc_mr;
> +       }
> +
>         mvq->modified_fields = 0;
>
>  done:
> @@ -3095,7 +3126,7 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
>         return mlx5_vdpa_update_cvq_iotlb(mvdev, iotlb, asid);
>
>  out_err:
> -       mlx5_vdpa_destroy_mr(mvdev, new_mr);
> +       mlx5_vdpa_put_mr(mvdev, new_mr);
>         return err;
>  }
>
> --
> 2.42.0
>


Reply via email to