diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h index 3ce96ce5fe3e..04f2f33607e9 100644 --- a/include/linux/memcontrol.h +++ b/include/linux/memcontrol.h @@ -1756,6 +1756,42 @@ static inline void count_objcg_event(struct obj_cgroup *objcg, rcu_read_unlock(); } +/** + * get_mem_cgroup_from_obj - get a memcg associated with passed kernel object. + * @p: pointer to object from which memcg should be extracted. It can be NULL. + * + * Retrieves the memory group into which the memory of the pointed kernel + * object is accounted. If memcg is found, its reference is taken. + * If a passed kernel object is uncharged, or if proper memcg cannot be found, + * as well as if mem_cgroup is disabled, NULL is returned. + * + * Return: valid memcg pointer with taken reference or NULL. + */ +static inline struct mem_cgroup *get_mem_cgroup_from_obj(void *p) +{ + struct mem_cgroup *memcg; + + rcu_read_lock(); + do { + memcg = mem_cgroup_from_obj(p); + } while (memcg && !css_tryget(&memcg->css)); + rcu_read_unlock(); + return memcg; +} + +/** + * mem_cgroup_or_root - always returns a pointer to a valid memory cgroup. + * @memcg: pointer to a valid memory cgroup or NULL. + * + * If passed argument is not NULL, returns it without any additional checks + * and changes. Otherwise, root_mem_cgroup is returned. + * + * NOTE: root_mem_cgroup can be NULL during early boot. + */ +static inline struct mem_cgroup *mem_cgroup_or_root(struct mem_cgroup *memcg) +{ + return memcg ? memcg : root_mem_cgroup; +} #else static inline bool mem_cgroup_kmem_disabled(void) { @@ -1799,7 +1835,7 @@ static inline int memcg_kmem_id(struct mem_cgroup *memcg) static inline struct mem_cgroup *mem_cgroup_from_obj(void *p) { - return NULL; + return NULL; } static inline struct mem_cgroup *mem_cgroup_from_slab_obj(void *p) @@ -1812,6 +1848,15 @@ static inline void count_objcg_event(struct obj_cgroup *objcg, { } +static inline struct mem_cgroup *get_mem_cgroup_from_obj(void *p) +{ + return NULL; +} + +static inline struct mem_cgroup *mem_cgroup_or_root(struct mem_cgroup *memcg) +{ + return NULL; +} #endif /* CONFIG_MEMCG_KMEM */ #if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP) diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index 0ec2f5906a27..6b9f19122ec1 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -1143,7 +1144,13 @@ static int __register_pernet_operations(struct list_head *list, * setup_net() and cleanup_net() are not possible. */ for_each_net(net) { + struct mem_cgroup *old, *memcg; + + memcg = mem_cgroup_or_root(get_mem_cgroup_from_obj(net)); + old = set_active_memcg(memcg); error = ops_init(ops, net); + set_active_memcg(old); + mem_cgroup_put(memcg); if (error) goto out_undo; list_add_tail(&net->exit_list, &net_exit_list);