From: Hui Zhu <[email protected]>

To allow more flexible attachment policies in nested cgroup
hierarchies, add support for the `BPF_F_ALLOW_OVERRIDE` flag to
`memcg_bpf_ops`.

When a `memcg_bpf_ops` is attached to a cgroup with this flag,
descendant cgroups may attach their own, different `memcg_bpf_ops`,
overriding the program they inherited. Such an override is refused if
any cgroup below the new attach point already runs a different
program. Without this flag, attaching a BPF program to a cgroup that
already has one (either directly or via inheritance) fails with
-EBUSY.

The implementation involves:
- Adding a `bpf_ops_flags` field to `struct mem_cgroup`.
- During registration (`bpf_memcg_ops_reg`), rejecting any attach flag
  other than `BPF_F_ALLOW_OVERRIDE` (-EOPNOTSUPP), then checking for
  existing programs and the `BPF_F_ALLOW_OVERRIDE` flag.
- During unregistration (`bpf_memcg_ops_unreg`), restoring the parent's
  BPF program and flags on every cgroup in the subtree that still uses
  the detached program.
- Inheriting the parent's flags, together with its program, when a
  child cgroup comes online.

This change enables complex, multi-level policy enforcement where
different subtrees of the cgroup hierarchy can have distinct memory
management BPF programs.

Signed-off-by: Geliang Tang <[email protected]>
Signed-off-by: Hui Zhu <[email protected]>
---
 include/linux/memcontrol.h |  1 +
 mm/bpf_memcontrol.c        | 82 ++++++++++++++++++++++++++------------
 2 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 24c4df864401..98c16e8dcd5b 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -354,6 +354,7 @@ struct mem_cgroup {
 
 #ifdef CONFIG_BPF_SYSCALL
        struct memcg_bpf_ops *bpf_ops;
+       u32 bpf_ops_flags;      /* attach flags, inherited with bpf_ops */
 #endif
 
        struct mem_cgroup_per_node *nodeinfo[];
diff --git a/mm/bpf_memcontrol.c b/mm/bpf_memcontrol.c
index e746eb9cbd56..7cd983e350d7 100644
--- a/mm/bpf_memcontrol.c
+++ b/mm/bpf_memcontrol.c
@@ -213,6 +213,8 @@ void memcontrol_bpf_online(struct mem_cgroup *memcg)
                goto out;
 
        WRITE_ONCE(memcg->bpf_ops, ops);
+       /* Inherit the flags too: bpf_memcg_ops_reg() checks them. */
+       memcg->bpf_ops_flags = parent_memcg->bpf_ops_flags;
 
        /*
         * If the BPF program implements it, call the online handler to
@@ -338,33 +339,19 @@ static int bpf_memcg_ops_init_member(const struct btf_type *t,
        return 0;
 }
 
-/**
- * clean_memcg_bpf_ops - Clear BPF ops from a memory cgroup hierarchy
- * @memcg: Root memory cgroup to start from
- * @ops: The specific BPF ops to remove
- *
- * Walks the cgroup hierarchy and clears bpf_ops for any cgroup that
- * matches @ops.
- */
-static void clean_memcg_bpf_ops(struct mem_cgroup *memcg,
-                               struct memcg_bpf_ops *ops)
-{
-       struct mem_cgroup *iter = NULL;
-
-       while ((iter = mem_cgroup_iter(memcg, iter, NULL))) {
-               if (READ_ONCE(iter->bpf_ops) == ops)
-                       WRITE_ONCE(iter->bpf_ops, NULL);
-       }
-}
-
 static int bpf_memcg_ops_reg(void *kdata, struct bpf_link *link)
 {
        struct bpf_struct_ops_link *ops_link
                = container_of(link, struct bpf_struct_ops_link, link);
-       struct memcg_bpf_ops *ops = kdata;
+       struct memcg_bpf_ops *ops = kdata, *old_ops;
        struct mem_cgroup *memcg, *iter = NULL;
        int err = 0;
 
+       if (ops_link->flags & ~BPF_F_ALLOW_OVERRIDE) {
+               pr_err("attach only support BPF_F_ALLOW_OVERRIDE\n");
+               return -EOPNOTSUPP;
+       }
+
        memcg = mem_cgroup_get_from_ino(ops_link->cgroup_id);
        if (!memcg)
                return -ENOENT;
@@ -372,16 +359,46 @@ static int bpf_memcg_ops_reg(void *kdata, struct bpf_link *link)
                return PTR_ERR(memcg);
 
        cgroup_lock();
+
+       /*
+        * Check if memcg has bpf_ops and whether it is inherited from
+        * parent.
+        * If inherited and BPF_F_ALLOW_OVERRIDE is set, allow override.
+        */
+       old_ops = READ_ONCE(memcg->bpf_ops);
+       if (old_ops) {
+               struct mem_cgroup *parent_memcg = parent_mem_cgroup(memcg);
+
+               if (!parent_memcg ||
+                   !(memcg->bpf_ops_flags & BPF_F_ALLOW_OVERRIDE) ||
+                   READ_ONCE(parent_memcg->bpf_ops) != old_ops) {
+                       err = -EBUSY;
+                       goto unlock_out;
+               }
+       }
+
+       /* Check for incompatible bpf_ops in descendants. */
        while ((iter = mem_cgroup_iter(memcg, iter, NULL))) {
-               if (READ_ONCE(iter->bpf_ops)) {
+               struct memcg_bpf_ops *iter_ops = READ_ONCE(iter->bpf_ops);
+
+               if (iter_ops && iter_ops != old_ops) {
+                       /*
+                        * Cannot override existing bpf_ops of a sub-cgroup;
+                        * end the walk so the reference on @iter is dropped.
+                        */
                        mem_cgroup_iter_break(memcg, iter);
                        err = -EBUSY;
-                       break;
+                       goto unlock_out;
                }
+       }
+
+       /* No conflicts: install @ops and its flags on the whole subtree. */
+       while ((iter = mem_cgroup_iter(memcg, iter, NULL))) {
                WRITE_ONCE(iter->bpf_ops, ops);
+               iter->bpf_ops_flags = ops_link->flags;
        }
-       if (err)
-               clean_memcg_bpf_ops(memcg, ops);
+
+unlock_out:
        cgroup_unlock();
 
        mem_cgroup_put(memcg);
@@ -395,13 +407,35 @@ static void bpf_memcg_ops_unreg(void *kdata, struct bpf_link *link)
                = container_of(link, struct bpf_struct_ops_link, link);
        struct memcg_bpf_ops *ops = kdata;
        struct mem_cgroup *memcg;
+       struct mem_cgroup *parent, *iter = NULL;
+       struct memcg_bpf_ops *parent_bpf_ops = NULL;
+       u32 parent_bpf_ops_flags = 0;
 
        memcg = mem_cgroup_get_from_ino(ops_link->cgroup_id);
        if (IS_ERR_OR_NULL(memcg))
                goto out;
 
        cgroup_lock();
-       clean_memcg_bpf_ops(memcg, ops);
+
+       /* Get the parent bpf_ops and bpf_ops_flags */
+       parent = parent_mem_cgroup(memcg);
+       if (parent) {
+               parent_bpf_ops = READ_ONCE(parent->bpf_ops);
+               parent_bpf_ops_flags = parent->bpf_ops_flags;
+       }
+
+       /*
+        * Give every cgroup in the subtree that still runs @ops the parent's
+        * bpf_ops (NULL if it has none) and flags. Sub-cgroups that overrode
+        * @ops with their own program are left untouched.
+        */
+       while ((iter = mem_cgroup_iter(memcg, iter, NULL))) {
+               if (READ_ONCE(iter->bpf_ops) == ops) {
+                       WRITE_ONCE(iter->bpf_ops, parent_bpf_ops);
+                       iter->bpf_ops_flags = parent_bpf_ops_flags;
+               }
+       }
+
        cgroup_unlock();
 
        mem_cgroup_put(memcg);
-- 
2.43.0


Reply via email to