diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 2b471419e26c..bfe011760ed1 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -1361,6 +1361,41 @@ struct iommu_group *fsl_mc_device_group(struct device *dev) } EXPORT_SYMBOL_GPL(fsl_mc_device_group); +static int iommu_alloc_default_domain(struct device *dev, + struct iommu_group *group) +{ + struct iommu_domain *dom; + + if (group->default_domain) + return 0; + + dom = __iommu_domain_alloc(dev->bus, iommu_def_domain_type); + if (!dom && iommu_def_domain_type != IOMMU_DOMAIN_DMA) { + dom = __iommu_domain_alloc(dev->bus, IOMMU_DOMAIN_DMA); + if (dom) { + dev_warn(dev, + "failed to allocate default IOMMU domain of type %u; falling back to IOMMU_DOMAIN_DMA", + iommu_def_domain_type); + } + } + + if (!dom) + return -ENOMEM; + + group->default_domain = dom; + if (!group->domain) + group->domain = dom; + + if (!iommu_dma_strict) { + int attr = 1; + iommu_domain_set_attr(dom, + DOMAIN_ATTR_DMA_USE_FLUSH_QUEUE, + &attr); + } + + return 0; +} + /** * iommu_group_get_for_dev - Find or create the IOMMU group for a device * @dev: target device @@ -1393,40 +1428,21 @@ struct iommu_group *iommu_group_get_for_dev(struct device *dev) /* * Try to allocate a default domain - needs support from the - * IOMMU driver. + * IOMMU driver. There are still some drivers which don't support + * default domains, so the return value is not yet checked. */ - if (!group->default_domain) { - struct iommu_domain *dom; - - dom = __iommu_domain_alloc(dev->bus, iommu_def_domain_type); - if (!dom && iommu_def_domain_type != IOMMU_DOMAIN_DMA) { - dom = __iommu_domain_alloc(dev->bus, IOMMU_DOMAIN_DMA); - if (dom) { - dev_warn(dev, - "failed to allocate default IOMMU domain of type %u; falling back to IOMMU_DOMAIN_DMA", - iommu_def_domain_type); - } - } - - group->default_domain = dom; - if (!group->domain) - group->domain = dom; - - if (dom && !iommu_dma_strict) { - int attr = 1; - iommu_domain_set_attr(dom, - DOMAIN_ATTR_DMA_USE_FLUSH_QUEUE, - &attr); - } - } + iommu_alloc_default_domain(dev, group); ret = iommu_group_add_device(group, dev); - if (ret) { - iommu_group_put(group); - return ERR_PTR(ret); - } + if (ret) + goto out_put_group; return group; + +out_put_group: + iommu_group_put(group); + + return ERR_PTR(ret); } EXPORT_SYMBOL(iommu_group_get_for_dev);