This patch introduces a new program type,
BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE.
A program of this type takes an additional parameter, probe_name, which is
used to locate the target tracepoint probe function (one registered through
register_trace_*) in the kernel.

This type reuses the existing RAW_TRACEPOINT infrastructure and differs
only when probe_name is specified. In that case, the newly attached
RAW_TRACEPOINT_OVERRIDE program replaces the target probe function, and
the two are paired and stored in a per-tracepoint snapshot.

When the BPF program is detached, snapshots are consulted to determine
whether restoration of the original probe function is required.

Signed-off-by: Fuyu Zhao <[email protected]>
---
 include/linux/bpf_types.h       |   2 +
 include/linux/trace_events.h    |   9 ++
 include/linux/tracepoint-defs.h |   6 +
 include/linux/tracepoint.h      |   3 +
 include/uapi/linux/bpf.h        |   2 +
 kernel/bpf/syscall.c            |  35 ++++--
 kernel/trace/bpf_trace.c        |  31 ++++++
 kernel/tracepoint.c             | 190 +++++++++++++++++++++++++++++++-
 8 files changed, 269 insertions(+), 9 deletions(-)

diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
index fa78f49d4a9a..e5cf8a1af6cd 100644
--- a/include/linux/bpf_types.h
+++ b/include/linux/bpf_types.h
@@ -48,6 +48,8 @@ BPF_PROG_TYPE(BPF_PROG_TYPE_RAW_TRACEPOINT_WRITABLE, 
raw_tracepoint_writable,
              struct bpf_raw_tracepoint_args, u64)
 BPF_PROG_TYPE(BPF_PROG_TYPE_TRACING, tracing,
              void *, void *)
+BPF_PROG_TYPE(BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE, raw_tracepoint_override,
+             struct bpf_raw_tracepoint_args, u64)
 #endif
 #ifdef CONFIG_CGROUP_BPF
 BPF_PROG_TYPE(BPF_PROG_TYPE_CGROUP_DEVICE, cg_dev,
diff --git a/include/linux/trace_events.h b/include/linux/trace_events.h
index 04307a19cde3..fcb2d62d0c9f 100644
--- a/include/linux/trace_events.h
+++ b/include/linux/trace_events.h
@@ -768,6 +768,9 @@ int perf_event_query_prog_array(struct perf_event *event, 
void __user *info);
 struct bpf_raw_tp_link;
 int bpf_probe_register(struct bpf_raw_event_map *btp, struct bpf_raw_tp_link 
*link);
 int bpf_probe_unregister(struct bpf_raw_event_map *btp, struct bpf_raw_tp_link 
*link);
+int bpf_probe_override(struct bpf_raw_event_map *btp,
+                      struct bpf_raw_tp_link *link,
+                      const char *probe_name);
 
 struct bpf_raw_event_map *bpf_get_raw_tracepoint(const char *name);
 void bpf_put_raw_tracepoint(struct bpf_raw_event_map *btp);
@@ -805,6 +808,12 @@ static inline int bpf_probe_unregister(struct 
bpf_raw_event_map *btp, struct bpf
 {
        return -EOPNOTSUPP;
 }
+/*
+ * Stub counterpart of bpf_probe_override(): performs no work and always
+ * reports -EOPNOTSUPP to the caller.
+ */
+static inline int bpf_probe_override(struct bpf_raw_event_map *btp,
+                                    struct bpf_raw_tp_link *link,
+                                    const char *probe_name)
+{
+       return -EOPNOTSUPP;
+}
 static inline struct bpf_raw_event_map *bpf_get_raw_tracepoint(const char 
*name)
 {
        return NULL;
diff --git a/include/linux/tracepoint-defs.h b/include/linux/tracepoint-defs.h
index aebf0571c736..9d7b1710c0aa 100644
--- a/include/linux/tracepoint-defs.h
+++ b/include/linux/tracepoint-defs.h
@@ -29,6 +29,11 @@ struct tracepoint_func {
        int prio;
 };
 
+/*
+ * Pairing of a probe that was replaced ("orig") with the probe installed in
+ * its place ("override"). Arrays of these entries are terminated by an entry
+ * whose override.func is NULL.
+ */
+struct tracepoint_func_snapshot {
+       struct tracepoint_func orig;            /* probe that was replaced */
+       struct tracepoint_func override;        /* probe installed instead */
+};
+
 struct tracepoint_ext {
        int (*regfunc)(void);
        void (*unregfunc)(void);
@@ -45,6 +50,7 @@ struct tracepoint {
        void *probestub;
        struct tracepoint_func __rcu *funcs;
        struct tracepoint_ext *ext;
+       struct tracepoint_func_snapshot *snapshot;
 };
 
 #ifdef CONFIG_HAVE_ARCH_PREL32_RELOCATIONS
diff --git a/include/linux/tracepoint.h b/include/linux/tracepoint.h
index 826ce3f8e1f8..399001e2afca 100644
--- a/include/linux/tracepoint.h
+++ b/include/linux/tracepoint.h
@@ -50,6 +50,9 @@ tracepoint_probe_register_may_exist(struct tracepoint *tp, 
void *probe,
        return tracepoint_probe_register_prio_may_exist(tp, probe, data,
                                                        
TRACEPOINT_DEFAULT_PRIO);
 }
+extern int
+tracepoint_probe_override(struct tracepoint *tp, void *probe, void *data,
+                         const char *func_replaced);
 extern void
 for_each_kernel_tracepoint(void (*fct)(struct tracepoint *tp, void *priv),
                void *priv);
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 233de8677382..cd3d889fe634 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -1071,6 +1071,7 @@ enum bpf_prog_type {
        BPF_PROG_TYPE_SK_LOOKUP,
        BPF_PROG_TYPE_SYSCALL, /* a program that can execute syscalls */
        BPF_PROG_TYPE_NETFILTER,
+       BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE,
        __MAX_BPF_PROG_TYPE
 };
 
@@ -1707,6 +1708,7 @@ union bpf_attr {
                __u32           prog_fd;
                __u32           :32;
                __aligned_u64   cookie;
+               __aligned_u64   probe_name;
        } raw_tracepoint;
 
        struct { /* anonymous struct for BPF_BTF_LOAD */
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index 3f178a0f8eb1..e360062db34e 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -4092,14 +4092,16 @@ static int bpf_perf_link_attach(const union bpf_attr 
*attr, struct bpf_prog *pro
 #endif /* CONFIG_PERF_EVENTS */
 
 static int bpf_raw_tp_link_attach(struct bpf_prog *prog,
-                                 const char __user *user_tp_name, u64 cookie,
+                                 const char __user *user_tp_name,
+                                 const char __user *user_probe_name,
+                                 u64 cookie,
                                  enum bpf_attach_type attach_type)
 {
        struct bpf_link_primer link_primer;
        struct bpf_raw_tp_link *link;
        struct bpf_raw_event_map *btp;
-       const char *tp_name;
-       char buf[128];
+       const char *tp_name, *probe_name;
+       char buf[128], probe[128];
        int err;
 
        switch (prog->type) {
@@ -4124,6 +4126,17 @@ static int bpf_raw_tp_link_attach(struct bpf_prog *prog,
                buf[sizeof(buf) - 1] = 0;
                tp_name = buf;
                break;
+       case BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE:
+               if (strncpy_from_user(buf, user_tp_name, sizeof(buf) - 1) < 0)
+                       return -EFAULT;
+               buf[sizeof(buf) - 1] = 0;
+               tp_name = buf;
+
+               if (strncpy_from_user(probe, user_probe_name, sizeof(probe) - 
1) < 0)
+                       return -EFAULT;
+               probe[sizeof(probe) - 1] = 0;
+               probe_name = probe;
+               break;
        default:
                return -EINVAL;
        }
@@ -4149,7 +4162,10 @@ static int bpf_raw_tp_link_attach(struct bpf_prog *prog,
                goto out_put_btp;
        }
 
-       err = bpf_probe_register(link->btp, link);
+       if (prog->type == BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE)
+               err = bpf_probe_override(link->btp, link, probe_name);
+       else
+               err = bpf_probe_register(link->btp, link);
        if (err) {
                bpf_link_cleanup(&link_primer);
                goto out_put_btp;
@@ -4162,12 +4178,12 @@ static int bpf_raw_tp_link_attach(struct bpf_prog *prog,
        return err;
 }
 
-#define BPF_RAW_TRACEPOINT_OPEN_LAST_FIELD raw_tracepoint.cookie
+#define BPF_RAW_TRACEPOINT_OPEN_LAST_FIELD raw_tracepoint.probe_name
 
 static int bpf_raw_tracepoint_open(const union bpf_attr *attr)
 {
        struct bpf_prog *prog;
-       void __user *tp_name;
+       void __user *tp_name, *probe_name;
        __u64 cookie;
        int fd;
 
@@ -4180,7 +4196,9 @@ static int bpf_raw_tracepoint_open(const union bpf_attr 
*attr)
 
        tp_name = u64_to_user_ptr(attr->raw_tracepoint.name);
        cookie = attr->raw_tracepoint.cookie;
-       fd = bpf_raw_tp_link_attach(prog, tp_name, cookie, 
prog->expected_attach_type);
+       probe_name = u64_to_user_ptr(attr->raw_tracepoint.probe_name);
+       fd = bpf_raw_tp_link_attach(prog, tp_name, probe_name,
+                                   cookie, prog->expected_attach_type);
        if (fd < 0)
                bpf_prog_put(prog);
        return fd;
@@ -5565,7 +5583,8 @@ static int link_create(union bpf_attr *attr, bpfptr_t 
uattr)
                        goto out;
                }
                if (prog->expected_attach_type == BPF_TRACE_RAW_TP)
-                       ret = bpf_raw_tp_link_attach(prog, NULL, 
attr->link_create.tracing.cookie,
+                       ret = bpf_raw_tp_link_attach(prog, NULL, NULL,
+                                                    
attr->link_create.tracing.cookie,
                                                     
attr->link_create.attach_type);
                else if (prog->expected_attach_type == BPF_TRACE_ITER)
                        ret = bpf_iter_link_attach(attr, uattr, prog);
diff --git a/kernel/trace/bpf_trace.c b/kernel/trace/bpf_trace.c
index 606007c387c5..1e965517ba05 100644
--- a/kernel/trace/bpf_trace.c
+++ b/kernel/trace/bpf_trace.c
@@ -1998,6 +1998,14 @@ const struct bpf_verifier_ops 
raw_tracepoint_writable_verifier_ops = {
 const struct bpf_prog_ops raw_tracepoint_writable_prog_ops = {
 };
 
+/*
+ * Verifier callbacks for BPF_PROG_TYPE_RAW_TRACEPOINT_OVERRIDE programs:
+ * helper lookup goes through raw_tp_prog_func_proto and context accesses are
+ * validated by raw_tp_writable_prog_is_valid_access.
+ */
+const struct bpf_verifier_ops raw_tracepoint_override_verifier_ops = {
+       .get_func_proto  = raw_tp_prog_func_proto,
+       .is_valid_access = raw_tp_writable_prog_is_valid_access,
+};
+
+/* No program-type specific operations are provided for this type. */
+const struct bpf_prog_ops raw_tracepoint_override_prog_ops = {
+};
+
 static bool pe_prog_is_valid_access(int off, int size, enum bpf_access_type 
type,
                                    const struct bpf_prog *prog,
                                    struct bpf_insn_access_aux *info)
@@ -2307,6 +2315,29 @@ BPF_TRACE_DEFN_x(10);
 BPF_TRACE_DEFN_x(11);
 BPF_TRACE_DEFN_x(12);
 
+/*
+ * bpf_probe_override - replace a tracepoint probe with a BPF raw tp program
+ * @btp:        raw tracepoint map entry supplying the tracepoint and bpf_func
+ * @link:       raw tracepoint link, passed to the tracepoint as probe data
+ * @probe_name: symbol name of the probe function to replace
+ *
+ * Returns -EINVAL when @probe_name is NULL, when the program's max_ctx_offset
+ * exceeds the tracepoint's argument area, or when its max_tp_access exceeds
+ * btp->writable_size; otherwise the result of tracepoint_probe_override().
+ */
+int bpf_probe_override(struct bpf_raw_event_map *btp,
+                      struct bpf_raw_tp_link *link,
+                      const char *probe_name)
+{
+       struct bpf_prog_aux *aux;
+
+       if (!probe_name)
+               return -EINVAL;
+
+       aux = link->link.prog->aux;
+
+       /* The program may only read arguments this tracepoint really has. */
+       if (aux->max_ctx_offset > btp->num_args * sizeof(u64))
+               return -EINVAL;
+
+       /* Its recorded max_tp_access must fit in btp->writable_size. */
+       if (aux->max_tp_access > btp->writable_size)
+               return -EINVAL;
+
+       return tracepoint_probe_override(btp->tp, (void *)btp->bpf_func,
+                                        link, probe_name);
+}
+
 int bpf_probe_register(struct bpf_raw_event_map *btp, struct bpf_raw_tp_link 
*link)
 {
        struct tracepoint *tp = btp->tp;
diff --git a/kernel/tracepoint.c b/kernel/tracepoint.c
index 62719d2941c9..3b8317306edc 100644
--- a/kernel/tracepoint.c
+++ b/kernel/tracepoint.c
@@ -14,6 +14,7 @@
 #include <linux/sched/signal.h>
 #include <linux/sched/task.h>
 #include <linux/static_key.h>
+#include <linux/kallsyms.h>
 
 enum tp_func_state {
        TP_FUNC_0,
@@ -130,6 +131,121 @@ static void debug_print_probes(struct tracepoint_func 
*funcs)
                printk(KERN_DEBUG "Probe %d : %pSb\n", i, funcs[i].func);
 }
 
+/*
+ * Scan the NULL-terminated probe array @funcs for the entry whose function
+ * address equals @probe_addr.
+ *
+ * Returns a pointer to the matching entry inside @funcs, or NULL when @funcs
+ * is NULL or holds no such probe.
+ */
+static struct tracepoint_func *
+find_func_to_override(struct tracepoint_func *funcs,
+                     unsigned long probe_addr)
+{
+       struct tracepoint_func *cur;
+
+       if (!funcs)
+               return NULL;
+
+       /* The array ends at the first entry with a NULL func. */
+       for (cur = funcs; cur->func; cur++)
+               if ((unsigned long)cur->func == probe_addr)
+                       return cur;
+
+       return NULL;
+}
+
+/*
+ * Look up the snapshot in *@ss that involves @func, comparing both the
+ * function pointer and the data pointer.
+ *
+ * On a hit, *@is_override is set to true when @func is the overriding probe
+ * of the snapshot and to false when it is the original one; the overriding
+ * side is checked first. *@is_override is left untouched on a miss.
+ *
+ * Returns the matching entry, or NULL when there is none.
+ */
+static struct tracepoint_func_snapshot *
+find_func_snapshot(struct tracepoint_func_snapshot **ss,
+                  struct tracepoint_func *func,
+                  bool *is_override)
+{
+       struct tracepoint_func_snapshot *cur = *ss;
+
+       if (!cur)
+               return NULL;
+
+       /* The array ends at the first entry with a NULL override.func. */
+       for (; cur->override.func; cur++) {
+               bool hit_override = cur->override.func == func->func &&
+                                   cur->override.data == func->data;
+               bool hit_orig = cur->orig.func == func->func &&
+                               cur->orig.data == func->data;
+
+               if (!hit_override && !hit_orig)
+                       continue;
+
+               *is_override = hit_override;
+               return cur;
+       }
+
+       return NULL;
+}
+
+/*
+ * Remove the entry @drop from the snapshot array *@ss.
+ *
+ * The array is terminated by an entry whose override.func is NULL. When the
+ * last remaining entry is removed the array is freed and *@ss becomes NULL.
+ * Otherwise a smaller array is allocated and the survivors are copied over;
+ * if that allocation fails the existing array is compacted in place, so the
+ * removal itself never fails.
+ *
+ * @drop must point into *@ss; any other pointer leaves the array unchanged.
+ */
+static void drop_func_snapshot(struct tracepoint_func_snapshot **ss,
+                              struct tracepoint_func_snapshot *drop)
+{
+       struct tracepoint_func_snapshot *old, *new;
+       int nr_snapshots;       /* Counter for snapshots */
+       int iter;               /* Iterate over old snapshots */
+       int idx = -1;           /* Index of snapshot to drop, -1 if absent */
+       int j = 0;              /* Next free slot in the new array */
+
+       old = *ss;
+       if (!old)
+               return;
+
+       for (nr_snapshots = 0; old[nr_snapshots].override.func; nr_snapshots++) {
+               if (&old[nr_snapshots] == drop)
+                       idx = nr_snapshots;
+       }
+
+       /* An array holding nothing but its terminator is of no use. */
+       if (nr_snapshots == 0) {
+               kfree(old);
+               *ss = NULL;
+               return;
+       }
+
+       /* Never drop an unrelated entry when @drop is not in the array. */
+       if (idx < 0)
+               return;
+
+       /* Dropping the only snapshot: release the array altogether. */
+       if (nr_snapshots == 1) {
+               kfree(old);
+               *ss = NULL;
+               return;
+       }
+
+       /* nr_snapshots - 1 surviving entries plus one terminator. */
+       new = kmalloc_array(nr_snapshots, sizeof(*new), GFP_KERNEL);
+       if (!new) {
+               /* No memory: close the gap in place and re-terminate. */
+               for (iter = idx; iter < nr_snapshots - 1; iter++)
+                       old[iter] = old[iter + 1];
+               memset(&old[nr_snapshots - 1], 0, sizeof(*old));
+               return;
+       }
+
+       for (iter = 0; iter < nr_snapshots; iter++) {
+               if (iter != idx)
+                       new[j++] = old[iter];
+       }
+       /* kmalloc memory is not zeroed: write the terminator explicitly. */
+       memset(&new[j], 0, sizeof(*new));
+
+       kfree(old);
+       *ss = new;
+}
+
+/*
+ * Append a snapshot pairing @old_func (the probe being replaced) with
+ * @new_func (the probe replacing it) to the array *@ss.
+ *
+ * A new array with room for the existing entries, the new entry and a
+ * terminator is allocated; on success the previous array is freed and *@ss
+ * points at the new one. Both probes are copied by value.
+ *
+ * Returns 0 on success or -ENOMEM, in which case *@ss is left unchanged.
+ */
+static int save_func_snapshot(struct tracepoint_func_snapshot **ss,
+                             struct tracepoint_func *new_func,
+                             struct tracepoint_func *old_func)
+{
+       struct tracepoint_func_snapshot *old, *new;
+       int nr_shots = 0;       /* Counter for old snapshots */
+       int total;              /* Total count of new snapshots */
+
+       old = *ss;
+       if (old)
+               while (old[nr_shots].override.func)
+                       nr_shots++;
+
+       /* + 2 : one for new snapshot, one for NULL snapshot */
+       total = nr_shots + 2;
+       new = kmalloc_array(total, sizeof(*new), GFP_KERNEL);
+       if (!new)
+               return -ENOMEM;
+
+       /* Skip the copy when there is nothing to copy (old may be NULL). */
+       if (nr_shots)
+               memcpy(new, old, nr_shots * sizeof(*new));
+       new[nr_shots].orig = *old_func;
+       new[nr_shots].override = *new_func;
+       /* Zero the whole terminator: entries are later copied as structs. */
+       memset(&new[nr_shots + 1], 0, sizeof(*new));
+
+       *ss = new;
+       kfree(old);
+       return 0;
+}
+
 static struct tracepoint_func *
 func_add(struct tracepoint_func **funcs, struct tracepoint_func *tp_func,
         int prio)
@@ -412,6 +528,52 @@ static int tracepoint_remove_func(struct tracepoint *tp,
        return 0;
 }
 
+/*
+ * Replace probe @func of @tp with @func_override.
+ *
+ * The original probe is removed first and the overriding probe is then added
+ * with its own priority. If the add fails, the original probe is put back
+ * (best effort) so that a failed override does not silently detach it.
+ *
+ * Returns 0 on success, or the error of the failing remove/add step.
+ */
+static int tracepoint_override_func(struct tracepoint *tp,
+                                   struct tracepoint_func *func,
+                                   struct tracepoint_func *func_override)
+{
+       /*
+        * Work on a private copy: tracepoint_probe_override() passes a pointer
+        * into the tp->funcs array, which presumably stops being valid once
+        * the probe has been removed from it.
+        */
+       struct tracepoint_func saved = *func;
+       int ret;
+
+       ret = tracepoint_remove_func(tp, &saved);
+       if (ret)
+               return ret;
+
+       ret = tracepoint_add_func(tp, func_override, func_override->prio,
+                                 false);
+       if (ret) {
+               /* Re-install the original; warn if even that is impossible. */
+               WARN_ON_ONCE(tracepoint_add_func(tp, &saved, saved.prio,
+                                                false));
+       }
+
+       return ret;
+}
+
+/*
+ * Undo an override: take @func out of @tp's probe list, then add
+ * @func_restore back using the priority stored in @func_restore.
+ *
+ * Returns the removal error if removing @func fails (nothing is added in
+ * that case), otherwise the result of adding @func_restore.
+ */
+static int tracepoint_restore_func(struct tracepoint *tp,
+                                  struct tracepoint_func *func,
+                                  struct tracepoint_func *func_restore)
+{
+       int err;
+
+       err = tracepoint_remove_func(tp, func);
+       if (err)
+               return err;
+
+       return tracepoint_add_func(tp, func_restore, func_restore->prio,
+                                  false);
+}
+
+/*
+ * tracepoint_probe_override - replace a registered probe of a tracepoint
+ * @tp:         tracepoint whose probe is to be replaced
+ * @probe:      replacement probe function
+ * @data:       private data passed to @probe
+ * @probe_name: symbol name of the currently registered probe to replace
+ *
+ * Resolves @probe_name to an address, finds the probe with that address in
+ * @tp, records an (original, override) snapshot and swaps the probes. The
+ * replacement inherits the priority of the probe it replaces.
+ *
+ * Returns 0 on success, -ESRCH when @probe_name does not resolve or is not
+ * registered on @tp, -ENOMEM when the snapshot cannot be stored, or the error
+ * of the swap itself.
+ */
+int tracepoint_probe_override(struct tracepoint *tp, void *probe,
+                             void *data, const char *probe_name)
+{
+       struct tracepoint_func_snapshot *shot;
+       struct tracepoint_func tp_func;
+       struct tracepoint_func target;
+       struct tracepoint_func *found;
+       unsigned long probe_addr;
+       bool is_override;
+       int ret;
+
+       /* kallsyms_lookup_name() yields 0 for an unknown symbol. */
+       probe_addr = kallsyms_lookup_name(probe_name);
+       if (!probe_addr)
+               return -ESRCH;
+
+       mutex_lock(&tracepoints_mutex);
+       found = find_func_to_override(
+                       rcu_dereference_protected(tp->funcs,
+                               lockdep_is_held(&tracepoints_mutex)),
+                       probe_addr);
+       if (!found) {
+               /* Must not return with tracepoints_mutex still held. */
+               ret = -ESRCH;
+               goto unlock;
+       }
+
+       /*
+        * Copy the entry: @found points into the tp->funcs array, which is
+        * replaced while the probe is being swapped.
+        */
+       target = *found;
+
+       tp_func.func = probe;
+       tp_func.data = data;
+       tp_func.prio = target.prio;
+       ret = save_func_snapshot(&(tp->snapshot), &tp_func, &target);
+       if (ret)
+               goto unlock;
+
+       ret = tracepoint_override_func(tp, &target, &tp_func);
+       if (ret) {
+               /* Do not keep a snapshot of an override that did not happen. */
+               shot = find_func_snapshot(&(tp->snapshot), &tp_func,
+                                         &is_override);
+               if (shot)
+                       drop_func_snapshot(&(tp->snapshot), shot);
+       }
+unlock:
+       mutex_unlock(&tracepoints_mutex);
+       return ret;
+}
+
 /**
  * tracepoint_probe_register_prio_may_exist -  Connect a probe to a tracepoint 
with priority
  * @tp: tracepoint
@@ -496,12 +658,38 @@ EXPORT_SYMBOL_GPL(tracepoint_probe_register);
 int tracepoint_probe_unregister(struct tracepoint *tp, void *probe, void *data)
 {
        struct tracepoint_func tp_func;
+       struct tracepoint_func_snapshot *shot;
        int ret;
+       bool is_override;       /* whether probe is an overriding func */
 
        mutex_lock(&tracepoints_mutex);
        tp_func.func = probe;
        tp_func.data = data;
-       ret = tracepoint_remove_func(tp, &tp_func);
+
+       /* Check whether this probe takes part in an override pairing. */
+       shot = find_func_snapshot(&(tp->snapshot), &tp_func, &is_override);
+       if (!shot) {
+               /* Not involved in any override: plain removal. */
+               ret = tracepoint_remove_func(tp, &tp_func);
+       } else {
+               /*
+                * Unregistering the overriding probe (the one registered by
+                * raw_tracepoint_open): restore the original tp_func.
+                *
+                * 1. remove the overriding probe and re-add the original
+                *    probe saved in the snapshot.
+                * 2. remove the snapshot (done below on success).
+                */
+               if (is_override)
+                       ret = tracepoint_restore_func(tp, &tp_func, 
&(shot->orig));
+               /*
+                * Unregistering the original probe (the one registered by
+                * register_trace_*) while it is overridden.
+                *
+                * 1. remove the current probe func (registered by
+                *    raw_tracepoint_open) from tp->funcs, since it stands in
+                *    for the original one.
+                * 2. remove the snapshot (done below on success).
+                *
+                * NOTE(review): nothing visible here informs the owner of the
+                * overriding probe that it has been removed; confirm that its
+                * own later unregister is handled.
+                */
+               else
+                       ret = tracepoint_remove_func(tp, &(shot->override));
+               /* Keep the snapshot if the probe list could not be updated. */
+               if (!ret)
+                       drop_func_snapshot(&(tp->snapshot), shot);
+       }
+
        mutex_unlock(&tracepoints_mutex);
        return ret;
 }
-- 
2.43.0


Reply via email to