>  
> +typedef void *tdx_vm_state_guard_t;
> +
> +static tdx_vm_state_guard_t tdx_acquire_vm_state_locks(struct kvm *kvm)
> +{
> +     int r;
> +
> +     mutex_lock(&kvm->lock);
> +
> +     if (kvm->created_vcpus != atomic_read(&kvm->online_vcpus)) {
> +             r = -EBUSY;
> +             goto out_err;
> +     }
> +
> +     r = kvm_lock_all_vcpus(kvm);
> +     if (r)
> +             goto out_err;
> +
> +     /*
> +      * Note the unintuitive ordering!  vcpu->mutex must be taken outside
> +      * kvm->slots_lock!
> +      */
> +     mutex_lock(&kvm->slots_lock);
> +     return kvm;
> +
> +out_err:
> +     mutex_unlock(&kvm->lock);
> +     return ERR_PTR(r);
> +}
> +
> +static void tdx_release_vm_state_locks(struct kvm *kvm)
> +{
> +     mutex_unlock(&kvm->slots_lock);
> +     kvm_unlock_all_vcpus(kvm);
> +     mutex_unlock(&kvm->lock);
> +}
> +
> +DEFINE_CLASS(tdx_vm_state_guard, tdx_vm_state_guard_t,
> +          if (!IS_ERR(_T)) tdx_release_vm_state_locks(_T),
> +          tdx_acquire_vm_state_locks(kvm), struct kvm *kvm);
> +
>  static int tdx_td_init(struct kvm *kvm, struct kvm_tdx_cmd *cmd)
>  {
>       struct kvm_tdx_init_vm __user *user_data = u64_to_user_ptr(cmd->data);
> @@ -2644,6 +2684,10 @@ static int tdx_td_init(struct kvm *kvm, struct kvm_tdx_cmd *cmd)
>       BUILD_BUG_ON(sizeof(*init_vm) != 256 + sizeof_field(struct kvm_tdx_init_vm, cpuid));
>       BUILD_BUG_ON(sizeof(struct td_params) != 1024);
>  
> +     CLASS(tdx_vm_state_guard, guard)(kvm);
> +     if (IS_ERR(guard))
> +             return PTR_ERR(guard);
> +
>       if (kvm_tdx->state != TD_STATE_UNINITIALIZED)
>               return -EINVAL;
>  
> @@ -2743,7 +2787,9 @@ static int tdx_td_finalize(struct kvm *kvm, struct kvm_tdx_cmd *cmd)
>  {
>       struct kvm_tdx *kvm_tdx = to_kvm_tdx(kvm);
>  
> -     guard(mutex)(&kvm->slots_lock);
> +     CLASS(tdx_vm_state_guard, guard)(kvm);
> +     if (IS_ERR(guard))
> +             return PTR_ERR(guard);
>  

Since you are changing both tdx_td_init() and tdx_td_finalize(), maybe
just acquire the guard in tdx_vm_ioctl() instead (like
tdx_vcpu_unlocked_ioctl())?  This is not required for
tdx_get_capabilities(), but there's no harm in doing it there too, AFAICT.

Reply via email to