Track per-task recursion depth in a simple hashtable keyed by PID. The
entry and exit handlers update the depth and act only when it matches the
configured recursion level.

Signed-off-by: Jinchao Wang <[email protected]>
---
 mm/kstackwatch/stack.c | 100 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 98 insertions(+), 2 deletions(-)

diff --git a/mm/kstackwatch/stack.c b/mm/kstackwatch/stack.c
index 3ea0f9de698e..669876057f0b 100644
--- a/mm/kstackwatch/stack.c
+++ b/mm/kstackwatch/stack.c
@@ -3,6 +3,8 @@
 
 #include <linux/atomic.h>
 #include <linux/fprobe.h>
+#include <linux/hash.h>
+#include <linux/hashtable.h>
 #include <linux/kprobes.h>
 #include <linux/printk.h>
 #include <linux/spinlock.h>
@@ -15,6 +17,83 @@ static struct fprobe exit_probe;
 static atomic_t ksw_stack_pid = ATOMIC_INIT(INVALID_PID);
 #define MAX_CANARY_SEARCH_STEPS 128
 
+struct depth_entry {
+       struct hlist_node node; /* chained into depth_hash, keyed by pid */
+       pid_t pid;              /* task whose nesting this entry tracks */
+       int depth;              /* nesting level; a task with no entry is at 0 */
+};
+
+#define DEPTH_HASH_BITS 8       /* 1 << 8 = 256 buckets */
+#define DEPTH_HASH_SIZE BIT(DEPTH_HASH_BITS)
+static DEFINE_HASHTABLE(depth_hash, DEPTH_HASH_BITS);
+static DEFINE_SPINLOCK(depth_hash_lock); /* guards depth_hash and its entries */
+
+/*
+ * Return current's recorded recursion depth, or 0 if it has no entry.
+ */
+static int get_recursive_depth(void)
+{
+       struct depth_entry *entry;
+       pid_t pid = current->pid;
+
+       /* irqsave: an IRQ that hits the probe must not deadlock on this lock */
+       guard(spinlock_irqsave)(&depth_hash_lock);
+       hash_for_each_possible(depth_hash, entry, node, pid) {
+               if (entry->pid == pid)
+                       return entry->depth;
+       }
+       return 0;
+}
+
+/* Record @depth for current; its entry is freed once depth returns to 0. */
+static void set_recursive_depth(int depth)
+{
+       struct depth_entry *entry;
+       pid_t pid = current->pid;
+
+       guard(spinlock_irqsave)(&depth_hash_lock);
+       hash_for_each_possible(depth_hash, entry, node, pid) {
+               if (entry->pid != pid)
+                       continue;
+               if (depth > 0) {
+                       entry->depth = depth;
+                       return;
+               }
+               /* last exit handler: the task has left the function */
+               hash_del(&entry->node);
+               kfree(entry);
+               return;
+       }
+
+       /*
+        * No entry means depth 0. Never store depth <= 0: an exit with no
+        * tracked entry (probe armed mid-call, failed kmalloc) would skew it.
+        */
+       if (depth <= 0)
+               return;
+       WARN_ONCE(depth != 1, "new entry depth %d should be 1", depth);
+       entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
+       if (!entry)
+               return;
+       entry->pid = pid;
+       entry->depth = depth;
+       hash_add(depth_hash, &entry->node, pid);
+}
+
+/* Free every tracked entry so all tasks start again at depth 0. */
+static void reset_recursive_depth(void)
+{
+       struct depth_entry *entry;
+       struct hlist_node *tmp;
+       int bkt;
+
+       guard(spinlock_irqsave)(&depth_hash_lock); /* also taken from probes */
+       hash_for_each_safe(depth_hash, bkt, tmp, entry, node) {
+               hash_del(&entry->node);
+               kfree(entry);
+       }
+}
+
 static unsigned long ksw_find_stack_canary_addr(struct pt_regs *regs)
 {
        unsigned long *stack_ptr, *stack_end, *stack_base;
@@ -109,8 +188,15 @@ static void ksw_stack_entry_handler(struct kprobe *p, struct pt_regs *regs,
 {
        u64 watch_addr;
        u64 watch_len;
+       int cur_depth;
        int ret;
 
+       cur_depth = get_recursive_depth();
+       set_recursive_depth(cur_depth + 1);
+
+       if (cur_depth != ksw_get_config()->depth)
+               return;
+
        if (atomic_cmpxchg(&ksw_stack_pid, INVALID_PID, current->pid) !=
            INVALID_PID)
                return;
@@ -126,8 +212,8 @@ static void ksw_stack_entry_handler(struct kprobe *p, struct pt_regs *regs,
        ret = ksw_watch_on(watch_addr, watch_len);
        if (ret) {
                atomic_set(&ksw_stack_pid, INVALID_PID);
-               pr_err("failed to watch on addr:0x%llx len:%llu %d\n",
-                      watch_addr, watch_len, ret);
+               pr_err("failed to watch on depth:%d addr:0x%llx len:%llu %d\n",
+                      cur_depth, watch_addr, watch_len, ret);
                return;
        }
 }
@@ -136,6 +222,14 @@ static void ksw_stack_exit_handler(struct fprobe *fp, unsigned long ip,
                                   unsigned long ret_ip,
                                   struct ftrace_regs *regs, void *data)
 {
+       int cur_depth;
+
+       cur_depth = get_recursive_depth() - 1;
+       set_recursive_depth(cur_depth);
+
+       if (cur_depth != ksw_get_config()->depth)
+               return;
+
        if (atomic_read(&ksw_stack_pid) != current->pid)
                return;
 
@@ -149,6 +243,8 @@ int ksw_stack_init(void)
        int ret;
        char *symbuf = NULL;
 
+       reset_recursive_depth();
+
        memset(&entry_probe, 0, sizeof(entry_probe));
        entry_probe.symbol_name = ksw_get_config()->function;
        entry_probe.offset = ksw_get_config()->ip_offset;
-- 
2.43.0


Reply via email to