Track per-task recursion depth with a hashtable keyed by PID. The entry and exit handlers increment and decrement the depth, and the watch logic runs only when the depth matches the configured recursion level.
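For illustration only (not part of the patch), a minimal userspace sketch of the gating semantics for a single task: cfg_depth stands in for ksw_get_config()->depth, and the printf calls stand in for arming/removing the watch. The entry handler acts on the depth seen before incrementing; the exit handler acts on the depth after decrementing.

	#include <stdio.h>

	static int depth;               /* recursion depth, starts from 0 */
	static const int cfg_depth = 1; /* assumed configured recursion level */

	static void entry_handler(void)
	{
		int cur = depth++;      /* depth before this invocation */

		if (cur == cfg_depth)
			printf("arm watch at depth %d\n", cur);
	}

	static void exit_handler(void)
	{
		int cur = --depth;      /* depth after this invocation returns */

		if (cur == cfg_depth)
			printf("remove watch at depth %d\n", cur);
	}

	int main(void)
	{
		entry_handler();        /* outer call,  0 -> 1: no trigger */
		entry_handler();        /* nested call, 1 -> 2: arms watch */
		exit_handler();         /* nested ret,  2 -> 1: removes    */
		exit_handler();         /* outer ret,   1 -> 0: no trigger */
		return 0;
	}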
Signed-off-by: Jinchao Wang <[email protected]>
---
 mm/kstackwatch/stack.c | 100 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 98 insertions(+), 2 deletions(-)

diff --git a/mm/kstackwatch/stack.c b/mm/kstackwatch/stack.c
index 3ea0f9de698e..669876057f0b 100644
--- a/mm/kstackwatch/stack.c
+++ b/mm/kstackwatch/stack.c
@@ -3,6 +3,8 @@
 
 #include <linux/atomic.h>
 #include <linux/fprobe.h>
+#include <linux/hash.h>
+#include <linux/hashtable.h>
 #include <linux/kprobes.h>
 #include <linux/printk.h>
 #include <linux/spinlock.h>
@@ -15,6 +17,83 @@ static struct fprobe exit_probe;
 static atomic_t ksw_stack_pid = ATOMIC_INIT(INVALID_PID);
 
 #define MAX_CANARY_SEARCH_STEPS 128
+struct depth_entry {
+	pid_t pid;
+	int depth; /* starts from 0 */
+	struct hlist_node node;
+};
+
+#define DEPTH_HASH_BITS 8
+#define DEPTH_HASH_SIZE BIT(DEPTH_HASH_BITS)
+static DEFINE_HASHTABLE(depth_hash, DEPTH_HASH_BITS);
+static DEFINE_SPINLOCK(depth_hash_lock);
+
+static int get_recursive_depth(void)
+{
+	struct depth_entry *entry;
+	pid_t pid = current->pid;
+	int depth = 0;
+
+	spin_lock(&depth_hash_lock);
+	hash_for_each_possible(depth_hash, entry, node, pid) {
+		if (entry->pid == pid) {
+			depth = entry->depth;
+			break;
+		}
+	}
+	spin_unlock(&depth_hash_lock);
+	return depth;
+}
+
+static void set_recursive_depth(int depth)
+{
+	struct depth_entry *entry;
+	pid_t pid = current->pid;
+	bool found = false;
+
+	spin_lock(&depth_hash_lock);
+	hash_for_each_possible(depth_hash, entry, node, pid) {
+		if (entry->pid == pid) {
+			entry->depth = depth;
+			found = true;
+			break;
+		}
+	}
+
+	if (found) {
+		// last exit handler
+		if (depth == 0) {
+			hash_del(&entry->node);
+			kfree(entry);
+		}
+		goto unlock;
+	}
+
+	WARN_ONCE(depth != 1, "new entry depth %d should be 1", depth);
+	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
+	if (entry) {
+		entry->pid = pid;
+		entry->depth = depth;
+		hash_add(depth_hash, &entry->node, pid);
+	}
+unlock:
+	spin_unlock(&depth_hash_lock);
+}
+
+static void reset_recursive_depth(void)
+{
+	struct depth_entry *entry;
+	struct hlist_node *tmp;
+	int bkt;
+
+	spin_lock(&depth_hash_lock);
+	hash_for_each_safe(depth_hash, bkt, tmp, entry, node) {
+		hash_del(&entry->node);
+		kfree(entry);
+	}
+	spin_unlock(&depth_hash_lock);
+}
+
 static unsigned long ksw_find_stack_canary_addr(struct pt_regs *regs)
 {
 	unsigned long *stack_ptr, *stack_end, *stack_base;
@@ -109,8 +188,15 @@ static void ksw_stack_entry_handler(struct kprobe *p, struct pt_regs *regs,
 {
 	u64 watch_addr;
 	u64 watch_len;
+	int cur_depth;
 	int ret;
 
+	cur_depth = get_recursive_depth();
+	set_recursive_depth(cur_depth + 1);
+
+	if (cur_depth != ksw_get_config()->depth)
+		return;
+
 	if (atomic_cmpxchg(&ksw_stack_pid, INVALID_PID, current->pid) !=
 	    INVALID_PID)
 		return;
@@ -126,8 +212,8 @@
 	ret = ksw_watch_on(watch_addr, watch_len);
 	if (ret) {
 		atomic_set(&ksw_stack_pid, INVALID_PID);
-		pr_err("failed to watch on addr:0x%llx len:%llu %d\n",
-		       watch_addr, watch_len, ret);
+		pr_err("failed to watch on depth:%d addr:0x%llx len:%llu %d\n",
+		       cur_depth, watch_addr, watch_len, ret);
 		return;
 	}
 }
@@ -136,6 +222,14 @@ static void ksw_stack_exit_handler(struct fprobe *fp, unsigned long ip,
 				   unsigned long ret_ip,
 				   struct ftrace_regs *regs, void *data)
 {
+	int cur_depth;
+
+	cur_depth = get_recursive_depth() - 1;
+	set_recursive_depth(cur_depth);
+
+	if (cur_depth != ksw_get_config()->depth)
+		return;
+
 	if (atomic_read(&ksw_stack_pid) != current->pid)
 		return;
 
@@ -149,6 +243,8 @@ int ksw_stack_init(void)
 	int ret;
 	char *symbuf = NULL;
 
+	reset_recursive_depth();
+
 	memset(&entry_probe, 0, sizeof(entry_probe));
 	entry_probe.symbol_name = ksw_get_config()->function;
 	entry_probe.offset = ksw_get_config()->ip_offset;
-- 
2.43.0
