diff --git a/arch/x86/include/asm/amd_iommu_types.h b/arch/x86/include/asm/amd_iommu_types.h index 5e8da56755dd..b150c74e0d48 100644 --- a/arch/x86/include/asm/amd_iommu_types.h +++ b/arch/x86/include/asm/amd_iommu_types.h @@ -200,6 +200,12 @@ (((address) | ((pagesize) - 1)) & \ (~(pagesize >> 1)) & PM_ADDR_MASK) +/* + * Takes a PTE value with mode=0x07 and returns the page size it maps + */ +#define PTE_PAGE_SIZE(pte) \ + (1ULL << (1 + ffz(((pte) | 0xfffULL)))) + #define IOMMU_PTE_P (1ULL << 0) #define IOMMU_PTE_TV (1ULL << 1) #define IOMMU_PTE_U (1ULL << 59) diff --git a/arch/x86/kernel/amd_iommu.c b/arch/x86/kernel/amd_iommu.c index 41700314f3e0..503d312f9d6f 100644 --- a/arch/x86/kernel/amd_iommu.c +++ b/arch/x86/kernel/amd_iommu.c @@ -776,28 +776,47 @@ static u64 *alloc_pte(struct protection_domain *domain, * This function checks if there is a PTE for a given dma address. If * there is one, it returns the pointer to it. */ -static u64 *fetch_pte(struct protection_domain *domain, - unsigned long address, int map_size) +static u64 *fetch_pte(struct protection_domain *domain, unsigned long address) { int level; u64 *pte; - level = domain->mode - 1; - pte = &domain->pt_root[PM_LEVEL_INDEX(level, address)]; + if (address > PM_LEVEL_SIZE(domain->mode)) + return NULL; - while (level > map_size) { + level = domain->mode - 1; + pte = &domain->pt_root[PM_LEVEL_INDEX(level, address)]; + + while (level > 0) { + + /* Not Present */ if (!IOMMU_PTE_PRESENT(*pte)) return NULL; + /* Large PTE */ + if (PM_PTE_LEVEL(*pte) == 0x07) { + unsigned long pte_mask, __pte; + + /* + * If we have a series of large PTEs, make + * sure to return a pointer to the first one. + */ + pte_mask = PTE_PAGE_SIZE(*pte); + pte_mask = ~((PAGE_SIZE_PTE_COUNT(pte_mask) << 3) - 1); + __pte = ((unsigned long)pte) & pte_mask; + + return (u64 *)__pte; + } + + /* No level skipping support yet */ + if (PM_PTE_LEVEL(*pte) != level) + return NULL; + level -= 1; + /* Walk to the next level */ pte = IOMMU_PTE_PAGE(*pte); pte = &pte[PM_LEVEL_INDEX(level, address)]; - - if ((PM_PTE_LEVEL(*pte) == 0) && level != map_size) { - pte = NULL; - break; - } } return pte; @@ -850,13 +869,48 @@ static int iommu_map_page(struct protection_domain *dom, return 0; } -static void iommu_unmap_page(struct protection_domain *dom, - unsigned long bus_addr, int map_size) +static unsigned long iommu_unmap_page(struct protection_domain *dom, + unsigned long bus_addr, + unsigned long page_size) { - u64 *pte = fetch_pte(dom, bus_addr, map_size); + unsigned long long unmap_size, unmapped; + u64 *pte; - if (pte) - *pte = 0; + BUG_ON(!is_power_of_2(page_size)); + + unmapped = 0; + + while (unmapped < page_size) { + + pte = fetch_pte(dom, bus_addr); + + if (!pte) { + /* + * No PTE for this address + * move forward in 4kb steps + */ + unmap_size = PAGE_SIZE; + } else if (PM_PTE_LEVEL(*pte) == 0) { + /* 4kb PTE found for this address */ + unmap_size = PAGE_SIZE; + *pte = 0ULL; + } else { + int count, i; + + /* Large PTE found which maps this address */ + unmap_size = PTE_PAGE_SIZE(*pte); + count = PAGE_SIZE_PTE_COUNT(unmap_size); + for (i = 0; i < count; i++) + pte[i] = 0ULL; + } + + bus_addr = (bus_addr & ~(unmap_size - 1)) + unmap_size; + unmapped += unmap_size; + } + + BUG_ON(!is_power_of_2(unmapped)); + + return unmapped; } /* @@ -1054,7 +1108,7 @@ static int alloc_new_range(struct dma_ops_domain *dma_dom, for (i = dma_dom->aperture[index]->offset; i < dma_dom->aperture_size; i += PAGE_SIZE) { - u64 *pte = fetch_pte(&dma_dom->domain, i, PM_MAP_4k); + u64 *pte = fetch_pte(&dma_dom->domain, i); if (!pte || !IOMMU_PTE_PRESENT(*pte)) continue; @@ -2491,7 +2545,7 @@ static void amd_iommu_unmap_range(struct iommu_domain *dom, iova &= PAGE_MASK; for (i = 0; i < npages; ++i) { - iommu_unmap_page(domain, iova, PM_MAP_4k); + iommu_unmap_page(domain, iova, PAGE_SIZE); iova += PAGE_SIZE; } @@ -2506,7 +2560,7 @@ static phys_addr_t amd_iommu_iova_to_phys(struct iommu_domain *dom, phys_addr_t paddr; u64 *pte; - pte = fetch_pte(domain, iova, PM_MAP_4k); + pte = fetch_pte(domain, iova); if (!pte || !IOMMU_PTE_PRESENT(*pte)) return 0;