RMDA/odp: Consolidate umem_odp initialization

This is done in two different places, consolidate all the post-allocation
initialization into a single function.

Link: https://lore.kernel.org/r/20190819111710.18440-5-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
This commit is contained in:
Jason Gunthorpe 2019-08-19 14:17:02 +03:00
parent fd7dbf035e
commit 22d79c9a91

View File

@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
.invalidate_range_end = ib_umem_notifier_invalidate_range_end,
};
static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
down_write(&per_mm->umem_rwsem);
/*
* Note that the representation of the intervals in the interval tree
* considers the ending point as contained in the interval, while the
* function ib_umem_end returns the first address which is not
* contained in the umem.
*/
umem_odp->interval_tree.start = ib_umem_start(umem_odp);
umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
up_write(&per_mm->umem_rwsem);
}
static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
@ -237,33 +220,23 @@ out_pid:
return ERR_PTR(ret);
}
static int get_per_mm(struct ib_umem_odp *umem_odp)
static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
struct ib_ucontext_per_mm *per_mm;
lockdep_assert_held(&ctx->per_mm_list_lock);
/*
* Generally speaking we expect only one or two per_mm in this list,
* so no reason to optimize this search today.
*/
mutex_lock(&ctx->per_mm_list_lock);
list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
if (per_mm->mm == umem_odp->umem.owning_mm)
goto found;
return per_mm;
}
per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
if (IS_ERR(per_mm)) {
mutex_unlock(&ctx->per_mm_list_lock);
return PTR_ERR(per_mm);
}
found:
umem_odp->per_mm = per_mm;
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);
return 0;
return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
}
static void free_per_mm(struct rcu_head *rcu)
@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
}
static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
struct ib_ucontext_per_mm *per_mm)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
int ret;
umem_odp->umem.is_odp = 1;
if (!umem_odp->is_implicit_odp) {
size_t pages = ib_umem_odp_num_pages(umem_odp);
if (!pages)
return -EINVAL;
/*
* Note that the representation of the intervals in the
* interval tree considers the ending point as contained in
* the interval, while the function ib_umem_end returns the
* first address which is not contained in the umem.
*/
umem_odp->interval_tree.start = ib_umem_start(umem_odp);
umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
umem_odp->page_list = vzalloc(
array_size(sizeof(*umem_odp->page_list), pages));
if (!umem_odp->page_list)
return -ENOMEM;
umem_odp->dma_list =
vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
if (!umem_odp->dma_list) {
ret = -ENOMEM;
goto out_page_list;
}
}
mutex_lock(&ctx->per_mm_list_lock);
if (!per_mm) {
per_mm = get_per_mm(umem_odp);
if (IS_ERR(per_mm)) {
ret = PTR_ERR(per_mm);
goto out_unlock;
}
}
umem_odp->per_mm = per_mm;
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);
mutex_init(&umem_odp->umem_mutex);
init_completion(&umem_odp->notifier_completion);
if (!umem_odp->is_implicit_odp) {
down_write(&per_mm->umem_rwsem);
interval_tree_insert(&umem_odp->interval_tree,
&per_mm->umem_tree);
up_write(&per_mm->umem_rwsem);
}
return 0;
out_unlock:
mutex_unlock(&ctx->per_mm_list_lock);
vfree(umem_odp->dma_list);
out_page_list:
vfree(umem_odp->page_list);
return ret;
}
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
unsigned long addr, size_t size)
{
struct ib_ucontext_per_mm *per_mm = root->per_mm;
struct ib_ucontext *ctx = per_mm->context;
/*
* Caller must ensure that root cannot be freed during the call to
* ib_alloc_odp_umem.
*/
struct ib_umem_odp *odp_data;
struct ib_umem *umem;
int pages = size >> PAGE_SHIFT;
int ret;
if (!size)
return ERR_PTR(-EINVAL);
odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
if (!odp_data)
return ERR_PTR(-ENOMEM);
umem = &odp_data->umem;
umem->context = ctx;
umem->context = root->umem.context;
umem->length = size;
umem->address = addr;
odp_data->page_shift = PAGE_SHIFT;
umem->writable = root->umem.writable;
umem->is_odp = 1;
odp_data->per_mm = per_mm;
umem->owning_mm = per_mm->mm;
mmgrab(umem->owning_mm);
umem->owning_mm = root->umem.owning_mm;
odp_data->page_shift = PAGE_SHIFT;
mutex_init(&odp_data->umem_mutex);
init_completion(&odp_data->notifier_completion);
odp_data->page_list =
vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
if (!odp_data->page_list) {
ret = -ENOMEM;
goto out_odp_data;
}
odp_data->dma_list =
vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
if (!odp_data->dma_list) {
ret = -ENOMEM;
goto out_page_list;
}
/*
* Caller must ensure that the umem_odp that the per_mm came from
* cannot be freed during the call to ib_alloc_odp_umem.
*/
mutex_lock(&ctx->per_mm_list_lock);
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);
add_umem_to_per_mm(odp_data);
return odp_data;
out_page_list:
vfree(odp_data->page_list);
out_odp_data:
mmdrop(umem->owning_mm);
ret = ib_init_umem_odp(odp_data, root->per_mm);
if (ret) {
kfree(odp_data);
return ERR_PTR(ret);
}
mmgrab(umem->owning_mm);
return odp_data;
}
EXPORT_SYMBOL(ib_alloc_odp_umem);
int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
struct ib_umem *umem = &umem_odp->umem;
/*
* NOTE: This must called in a process context where umem->owning_mm
* == current->mm
*/
struct mm_struct *mm = umem->owning_mm;
int ret_val;
struct mm_struct *mm = umem_odp->umem.owning_mm;
if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
umem_odp->is_implicit_odp = 1;
@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
up_read(&mm->mmap_sem);
}
mutex_init(&umem_odp->umem_mutex);
init_completion(&umem_odp->notifier_completion);
if (!umem_odp->is_implicit_odp) {
if (!ib_umem_odp_num_pages(umem_odp))
return -EINVAL;
umem_odp->page_list =
vzalloc(array_size(sizeof(*umem_odp->page_list),
ib_umem_odp_num_pages(umem_odp)));
if (!umem_odp->page_list)
return -ENOMEM;
umem_odp->dma_list =
vzalloc(array_size(sizeof(*umem_odp->dma_list),
ib_umem_odp_num_pages(umem_odp)));
if (!umem_odp->dma_list) {
ret_val = -ENOMEM;
goto out_page_list;
}
}
ret_val = get_per_mm(umem_odp);
if (ret_val)
goto out_dma_list;
if (!umem_odp->is_implicit_odp)
add_umem_to_per_mm(umem_odp);
return 0;
out_dma_list:
vfree(umem_odp->dma_list);
out_page_list:
vfree(umem_odp->page_list);
return ret_val;
return ib_init_umem_odp(umem_odp, NULL);
}
void ib_umem_odp_release(struct ib_umem_odp *umem_odp)