Several places in the code need to look up the svm and sdev pointers
for a given PASID and device. Add a helper that performs this lookup to
consolidate the code and improve readability.

Signed-off-by: Lu Baolu <baolu...@linux.intel.com>
---
 drivers/iommu/intel/svm.c | 121 +++++++++++++++++++++-----------------
 1 file changed, 68 insertions(+), 53 deletions(-)

diff --git a/drivers/iommu/intel/svm.c b/drivers/iommu/intel/svm.c
index 25dd74f27252..c23167877b2b 100644
--- a/drivers/iommu/intel/svm.c
+++ b/drivers/iommu/intel/svm.c
@@ -228,6 +228,75 @@ static LIST_HEAD(global_svm_list);
        list_for_each_entry((sdev), &(svm)->devs, list) \
                if ((d) != (sdev)->dev) {} else
 
+/**
+ * pasid_to_svm_sdev - look up the svm and sdev for a PASID and device
+ * @dev: device whose binding is looked up
+ * @pasid: PASID to look up
+ * @rsvm: set to the svm registered for @pasid, or NULL if there is none
+ * @rsdev: set to the sdev of @dev within that svm, or NULL if @dev is
+ *         not bound to it
+ *
+ * Both output pointers are cleared on entry, so they are never left
+ * uninitialized when an error is returned.
+ *
+ * Context: The caller must hold pasid_mutex.
+ *
+ * Return: 0 on success, including the "not found" cases above; a negative
+ * errno if pasid_mutex is not locked, @pasid is out of range, the PASID
+ * lookup fails, or the svm found has no devices bound to it.
+ */
+static int pasid_to_svm_sdev(struct device *dev, unsigned int pasid,
+                            struct intel_svm **rsvm,
+                            struct intel_svm_dev **rsdev)
+{
+       struct intel_svm_dev *d, *sdev = NULL;
+       struct intel_svm *svm;
+
+       *rsvm = NULL;
+       *rsdev = NULL;
+
+       /*
+        * lockdep verifies that the current task holds pasid_mutex; the
+        * mutex_is_locked() check below still catches an unlocked caller on
+        * kernels built without lockdep.
+        */
+       lockdep_assert_held(&pasid_mutex);
+       if (WARN_ON(!mutex_is_locked(&pasid_mutex)))
+               return -EINVAL;
+
+       if (pasid == INVALID_IOASID || pasid >= PASID_MAX)
+               return -EINVAL;
+
+       svm = ioasid_find(NULL, pasid, NULL);
+       if (IS_ERR(svm))
+               return PTR_ERR(svm);
+
+       /* No svm registered for this PASID yet: not an error. */
+       if (!svm)
+               return 0;
+
+       /*
+        * If we found svm for the PASID, there must be at least one device
+        * bound to it; otherwise the svm should have been freed.
+        */
+       if (WARN_ON(list_empty(&svm->devs)))
+               return -EINVAL;
+
+       rcu_read_lock();
+       list_for_each_entry_rcu(d, &svm->devs, list) {
+               if (d->dev == dev) {
+                       sdev = d;
+                       break;
+               }
+       }
+       rcu_read_unlock();
+
+       *rsvm = svm;
+       *rsdev = sdev;
+
+       return 0;
+}
+
 int intel_svm_bind_gpasid(struct iommu_domain *domain, struct device *dev,
                          struct iommu_gpasid_bind_data *data)
 {
@@ -261,39 +305,27 @@ int intel_svm_bind_gpasid(struct iommu_domain *domain, 
struct device *dev,
        dmar_domain = to_dmar_domain(domain);
 
        mutex_lock(&pasid_mutex);
-       svm = ioasid_find(NULL, data->hpasid, NULL);
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, data->hpasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
 
-       if (svm) {
+       if (sdev) {
                /*
-                * If we found svm for the PASID, there must be at
-                * least one device bond, otherwise svm should be freed.
+                * For devices with aux domains, we should allow
+                * multiple bind calls with the same PASID and pdev.
                 */
-               if (WARN_ON(list_empty(&svm->devs))) {
-                       ret = -EINVAL;
-                       goto out;
+               if (iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX)) {
+                       sdev->users++;
+               } else {
+                       dev_warn_ratelimited(dev,
+                                            "Already bound with PASID %u\n",
+                                            svm->pasid);
+                       ret = -EBUSY;
                }
+               goto out;
+       }
 
-               for_each_svm_dev(sdev, svm, dev) {
-                       /*
-                        * For devices with aux domains, we should allow
-                        * multiple bind calls with the same PASID and pdev.
-                        */
-                       if (iommu_dev_feature_enabled(dev,
-                                                     IOMMU_DEV_FEAT_AUX)) {
-                               sdev->users++;
-                       } else {
-                               dev_warn_ratelimited(dev,
-                                                    "Already bound with PASID 
%u\n",
-                                                    svm->pasid);
-                               ret = -EBUSY;
-                       }
-                       goto out;
-               }
-       } else {
+       if (!svm) {
                /* We come here when PASID has never been bond to a device. */
                svm = kzalloc(sizeof(*svm), GFP_KERNEL);
                if (!svm) {
@@ -376,25 +408,17 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
        struct intel_iommu *iommu = device_to_iommu(dev, NULL, NULL);
        struct intel_svm_dev *sdev;
        struct intel_svm *svm;
-       int ret = -EINVAL;
+       int ret;
 
        if (WARN_ON(!iommu))
                return -EINVAL;
 
        mutex_lock(&pasid_mutex);
-       svm = ioasid_find(NULL, pasid, NULL);
-       if (!svm) {
-               ret = -EINVAL;
-               goto out;
-       }
-
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
 
-       for_each_svm_dev(sdev, svm, dev) {
-               ret = 0;
+       if (sdev) {
                if (iommu_dev_feature_enabled(dev, IOMMU_DEV_FEAT_AUX))
                        sdev->users--;
                if (!sdev->users) {
@@ -418,7 +442,6 @@ int intel_svm_unbind_gpasid(struct device *dev, int pasid)
                                kfree(svm);
                        }
                }
-               break;
        }
 out:
        mutex_unlock(&pasid_mutex);
@@ -596,7 +619,7 @@ intel_svm_bind_mm(struct device *dev, int flags, struct 
svm_dev_ops *ops,
        if (sd)
                *sd = sdev;
        ret = 0;
- out:
+out:
        return ret;
 }
 
@@ -612,17 +635,11 @@ static int intel_svm_unbind_mm(struct device *dev, int 
pasid)
        if (!iommu)
                goto out;
 
-       svm = ioasid_find(NULL, pasid, NULL);
-       if (!svm)
-               goto out;
-
-       if (IS_ERR(svm)) {
-               ret = PTR_ERR(svm);
+       ret = pasid_to_svm_sdev(dev, pasid, &svm, &sdev);
+       if (ret)
                goto out;
-       }
 
-       for_each_svm_dev(sdev, svm, dev) {
-               ret = 0;
+       if (sdev) {
                sdev->users--;
                if (!sdev->users) {
                        list_del_rcu(&sdev->list);
@@ -651,10 +668,8 @@ static int intel_svm_unbind_mm(struct device *dev, int 
pasid)
                                kfree(svm);
                        }
                }
-               break;
        }
- out:
-
+out:
        return ret;
 }
 
-- 
2.17.1

_______________________________________________
iommu mailing list
iommu@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/iommu

Reply via email to