On Tue, Oct 29, 2024 at 04:09:41PM -0300, Jason Gunthorpe wrote:
> On Fri, Oct 25, 2024 at 04:50:34PM -0700, Nicolin Chen wrote:
> > @@ -497,17 +497,35 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
> >             goto out;
> >     }
> >  
> > -   hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id);
> > -   if (IS_ERR(hwpt)) {
> > -           rc = PTR_ERR(hwpt);
> > +   pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY);
> > +   if (IS_ERR(pt_obj)) {
> > +           rc = PTR_ERR(pt_obj);
> >             goto out;
> >     }
> > +   if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
> > +           struct iommufd_hw_pagetable *hwpt =
> > +                   container_of(pt_obj, struct iommufd_hw_pagetable, obj);
> > +
> > +           rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
> > +                                                         &data_array);
> > +   } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
> > +           struct iommufd_viommu *viommu =
> > +                   container_of(pt_obj, struct iommufd_viommu, obj);
> > +
> > +           if (!viommu->ops || !viommu->ops->cache_invalidate) {
> > +                   rc = -EOPNOTSUPP;
> > +                   goto out_put_pt;
> > +           }
> > +           rc = viommu->ops->cache_invalidate(viommu, &data_array);
> > +   } else {
> > +           rc = -EINVAL;
> > +           goto out_put_pt;
> > +   }
> 
> Given the test in iommufd_viommu_alloc_hwpt_nested() is:
> 
>       if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
>                        (!viommu->ops->cache_invalidate &&
>                         !hwpt->domain->ops->cache_invalidate_user)))
>                         {
> 
> We will crash if the user passes a viommu-allocated domain as
> IOMMUFD_OBJ_HWPT_NESTED since the above doesn't check it.

Ah, that was missed.

> I suggest we put the required if (ops..) -EOPNOTSUPP above and remove
> the ops->cache_invalidate checks from both WARN_ONs.

Ack. I will add a hwpt->domain->ops->cache_invalidate_user check:
---------------------------------------------------------------------
        if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
                struct iommufd_hw_pagetable *hwpt =
                        container_of(pt_obj, struct iommufd_hw_pagetable, obj);
        
                if (!hwpt->domain->ops ||
                    !hwpt->domain->ops->cache_invalidate_user) {
                        rc = -EOPNOTSUPP;
                        goto out_put_pt;
                }
                rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
                                                              &data_array);
        } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
                struct iommufd_viommu *viommu =
                        container_of(pt_obj, struct iommufd_viommu, obj);
        
                if (!viommu->ops || !viommu->ops->cache_invalidate) {
                        rc = -EOPNOTSUPP;
                        goto out_put_pt;
                }
                rc = viommu->ops->cache_invalidate(viommu, &data_array);
        } else {
---------------------------------------------------------------------

Thanks
Nicolin

Reply via email to