Hi Eric,

On 2020/11/18 19:21, Eric Auger wrote:
> From: Jean-Philippe Brucker <jean-phili...@linaro.org>
> 
> When handling faults from the event or PRI queue, we need to find the
> struct device associated to a SID. Add a rb_tree to keep track of SIDs.
> 
> Signed-off-by: Jean-Philippe Brucker <jean-phili...@linaro.org>
[...]

>  }
>  
> +static int arm_smmu_insert_master(struct arm_smmu_device *smmu,
> +                               struct arm_smmu_master *master)
> +{
> +     int i;
> +     int ret = 0;
> +     struct arm_smmu_stream *new_stream, *cur_stream;
> +     struct rb_node **new_node, *parent_node = NULL;
> +     struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(master->dev);
> +
> +     master->streams = kcalloc(fwspec->num_ids,
> +                               sizeof(struct arm_smmu_stream), GFP_KERNEL);
> +     if (!master->streams)
> +             return -ENOMEM;
> +     master->num_streams = fwspec->num_ids;
This is not rolled back on failure: master->num_streams keeps this value even after master->streams has been freed.

> +
> +     mutex_lock(&smmu->streams_mutex);
> +     for (i = 0; i < fwspec->num_ids && !ret; i++) {
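Checking ret in the loop condition makes it hard to determine the start index of the rollback.

If we fail here (duplicate SID: the inner break only leaves the while loop,
so i++ still runs before !ret is tested), the start index is (i-2).
If we fail inside the loop body (SID out of range, or L2 strtab init failure,
both of which break out before i++), the start index is (i-1).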

> +             u32 sid = fwspec->ids[i];
> +
> +             new_stream = &master->streams[i];
> +             new_stream->id = sid;
> +             new_stream->master = master;
> +
> +             /*
> +              * Check the SIDs are in range of the SMMU and our stream table
> +              */
> +             if (!arm_smmu_sid_in_range(smmu, sid)) {
> +                     ret = -ERANGE;
> +                     break;
> +             }
> +
> +             /* Ensure l2 strtab is initialised */
> +             if (smmu->features & ARM_SMMU_FEAT_2_LVL_STRTAB) {
> +                     ret = arm_smmu_init_l2_strtab(smmu, sid);
> +                     if (ret)
> +                             break;
> +             }
> +
> +             /* Insert into SID tree */
> +             new_node = &(smmu->streams.rb_node);
> +             while (*new_node) {
> +                     cur_stream = rb_entry(*new_node, struct arm_smmu_stream,
> +                                           node);
> +                     parent_node = *new_node;
> +                     if (cur_stream->id > new_stream->id) {
> +                             new_node = &((*new_node)->rb_left);
> +                     } else if (cur_stream->id < new_stream->id) {
> +                             new_node = &((*new_node)->rb_right);
> +                     } else {
> +                             dev_warn(master->dev,
> +                                      "stream %u already in tree\n",
> +                                      cur_stream->id);
> +                             ret = -EINVAL;
> +                             break;
> +                     }
> +             }
> +
> +             if (!ret) {
> +                     rb_link_node(&new_stream->node, parent_node, new_node);
> +                     rb_insert_color(&new_stream->node, &smmu->streams);
> +             }
> +     }
> +
> +     if (ret) {
> +             for (; i > 0; i--)
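Should this be (i >= 0)?
The start index also looks wrong: the first iteration erases streams[i], which was never inserted into the tree, and streams[0] is never erased.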

> +                     rb_erase(&master->streams[i].node, &smmu->streams);
> +             kfree(master->streams);
> +     }
> +     mutex_unlock(&smmu->streams_mutex);
> +
> +     return ret;
> +}
> +
> +static void arm_smmu_remove_master(struct arm_smmu_master *master)
> +{
> +     int i;
> +     struct arm_smmu_device *smmu = master->smmu;
> +     struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(master->dev);
> +
> +     if (!smmu || !master->streams)
> +             return;
> +
> +     mutex_lock(&smmu->streams_mutex);
> +     for (i = 0; i < fwspec->num_ids; i++)
> +             rb_erase(&master->streams[i].node, &smmu->streams);
> +     mutex_unlock(&smmu->streams_mutex);
> +
> +     kfree(master->streams);
> +}
> +
>  static struct iommu_ops arm_smmu_ops;
>  
>  static struct iommu_device *arm_smmu_probe_device(struct device *dev)
>  {
> -     int i, ret;
> +     int ret;
>       struct arm_smmu_device *smmu;
>       struct arm_smmu_master *master;
>       struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
> @@ -2331,27 +2447,12 @@ static struct iommu_device *arm_smmu_probe_device(struct device *dev)
>  
>       master->dev = dev;
>       master->smmu = smmu;
> -     master->sids = fwspec->ids;
> -     master->num_sids = fwspec->num_ids;
>       INIT_LIST_HEAD(&master->bonds);
>       dev_iommu_priv_set(dev, master);
>  
> -     /* Check the SIDs are in range of the SMMU and our stream table */
> -     for (i = 0; i < master->num_sids; i++) {
> -             u32 sid = master->sids[i];
> -
> -             if (!arm_smmu_sid_in_range(smmu, sid)) {
> -                     ret = -ERANGE;
> -                     goto err_free_master;
> -             }
> -
> -             /* Ensure l2 strtab is initialised */
> -             if (smmu->features & ARM_SMMU_FEAT_2_LVL_STRTAB) {
> -                     ret = arm_smmu_init_l2_strtab(smmu, sid);
> -                     if (ret)
> -                             goto err_free_master;
> -             }
> -     }
> +     ret = arm_smmu_insert_master(smmu, master);
> +     if (ret)
> +             goto err_free_master;
>  
>       master->ssid_bits = min(smmu->ssid_bits, fwspec->num_pasid_bits);
>  
> @@ -2389,6 +2490,7 @@ static void arm_smmu_release_device(struct device *dev)
>       WARN_ON(arm_smmu_master_sva_enabled(master));
>       arm_smmu_detach_dev(master);
>       arm_smmu_disable_pasid(master);
> +     arm_smmu_remove_master(master);
>       kfree(master);

Thanks,
Keqian
_______________________________________________
iommu mailing list
iommu@lists.linux-foundation.org
https://lists.linuxfoundation.org/mailman/listinfo/iommu

Reply via email to