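In order to put a bound on the uretprobes_srcu critical section, add a
timer to uprobe_task. Every time a return_instance (RI) is added or
removed, the timer is pushed forward to now + 1s. If the timer were ever
to fire, it would queue task work that converts each RI's SRCU
'reference' to a refcount reference if possible; when the uprobe's
refcount has already dropped to zero, the RI's uprobe pointer is cleared
instead.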

Signed-off-by: Peter Zijlstra (Intel) <pet...@infradead.org>
---
 include/linux/uprobes.h |    8 +++++
 kernel/events/uprobes.c |   67 ++++++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 69 insertions(+), 6 deletions(-)

--- a/include/linux/uprobes.h
+++ b/include/linux/uprobes.h
@@ -15,6 +15,7 @@
 #include <linux/rbtree.h>
 #include <linux/types.h>
 #include <linux/wait.h>
+#include <linux/timer.h>
 
 struct vm_area_struct;
 struct mm_struct;
@@ -79,6 +80,10 @@ struct uprobe_task {
        struct return_instance          *return_instances;
        unsigned int                    depth;
        unsigned int                    active_srcu_idx;
+
+       struct timer_list               ri_timer;
+       struct callback_head            ri_task_work;
+       struct task_struct              *task;
 };
 
 struct return_instance {
@@ -86,7 +91,8 @@ struct return_instance {
        unsigned long           func;
        unsigned long           stack;          /* stack pointer */
        unsigned long           orig_ret_vaddr; /* original return address */
-       bool                    chained;        /* true, if instance is nested 
*/
+       u8                      chained;        /* true, if instance is nested 
*/
+       u8                      has_ref;
        int                     srcu_idx;
 
        struct return_instance  *next;          /* keep as stack */
--- a/kernel/events/uprobes.c
+++ b/kernel/events/uprobes.c
@@ -1761,7 +1761,12 @@ unsigned long uprobe_get_trap_addr(struc
 static struct return_instance *free_ret_instance(struct return_instance *ri)
 {
        struct return_instance *next = ri->next;
-       __srcu_read_unlock(&uretprobes_srcu, ri->srcu_idx);
+       if (ri->uprobe) {
+               if (ri->has_ref)
+                       put_uprobe(ri->uprobe);
+               else
+                       __srcu_read_unlock(&uretprobes_srcu, ri->srcu_idx);
+       }
        kfree(ri);
        return next;
 }
@@ -1785,11 +1790,48 @@ void uprobe_free_utask(struct task_struc
        while (ri)
                ri = free_ret_instance(ri);
 
+       timer_delete_sync(&utask->ri_timer);
+       task_work_cancel(utask->task, &utask->ri_task_work);
        xol_free_insn_slot(t);
        kfree(utask);
        t->utask = NULL;
 }
 
+static void return_instance_task_work(struct callback_head *head)
+{
+       struct uprobe_task *utask = container_of(head, struct uprobe_task, 
ri_task_work);
+       struct return_instance *ri;
+
+       for (ri = utask->return_instances; ri; ri = ri->next) {
+               if (!ri->uprobe)
+                       continue;
+               if (ri->has_ref)
+                       continue;
+               if (refcount_inc_not_zero(&ri->uprobe->ref))
+                       ri->has_ref = true;
+               else
+                       ri->uprobe = NULL;
+               __srcu_read_unlock(&uretprobes_srcu, ri->srcu_idx);
+       }
+}
+
+static void return_instance_timer(struct timer_list *timer)
+{
+       struct uprobe_task *utask = container_of(timer, struct uprobe_task, 
ri_timer);
+       task_work_add(utask->task, &utask->ri_task_work, TWA_SIGNAL);
+}
+
+static struct uprobe_task *alloc_utask(struct task_struct *task)
+{
+       struct uprobe_task *utask = kzalloc(sizeof(struct uprobe_task), 
GFP_KERNEL);
+       if (!utask)
+               return NULL;
+       timer_setup(&utask->ri_timer, return_instance_timer, 0);
+       init_task_work(&utask->ri_task_work, return_instance_task_work);
+       utask->task = task;
+       return utask;
+}
+
 /*
  * Allocate a uprobe_task object for the task if necessary.
  * Called when the thread hits a breakpoint.
@@ -1801,7 +1843,7 @@ void uprobe_free_utask(struct task_struc
 static struct uprobe_task *get_utask(void)
 {
        if (!current->utask)
-               current->utask = kzalloc(sizeof(struct uprobe_task), 
GFP_KERNEL);
+               current->utask = alloc_utask(current);
        return current->utask;
 }
 
@@ -1810,7 +1852,7 @@ static int dup_utask(struct task_struct
        struct uprobe_task *n_utask;
        struct return_instance **p, *o, *n;
 
-       n_utask = kzalloc(sizeof(struct uprobe_task), GFP_KERNEL);
+       n_utask = alloc_utask(t);
        if (!n_utask)
                return -ENOMEM;
        t->utask = n_utask;
@@ -1822,13 +1864,20 @@ static int dup_utask(struct task_struct
                        return -ENOMEM;
 
                *n = *o;
-               __srcu_clone_read_lock(&uretprobes_srcu, n->srcu_idx);
+               if (n->uprobe) {
+                       if (n->has_ref)
+                               get_uprobe(n->uprobe);
+                       else
+                               __srcu_clone_read_lock(&uretprobes_srcu, 
n->srcu_idx);
+               }
                n->next = NULL;
 
                *p = n;
                p = &n->next;
                n_utask->depth++;
        }
+       if (n_utask->return_instances)
+               mod_timer(&n_utask->ri_timer, jiffies + HZ);
 
        return 0;
 }
@@ -1967,6 +2016,7 @@ static void prepare_uretprobe(struct upr
 
        ri->srcu_idx = __srcu_read_lock(&uretprobes_srcu);
        ri->uprobe = uprobe;
+       ri->has_ref = 0;
        ri->func = instruction_pointer(regs);
        ri->stack = user_stack_pointer(regs);
        ri->orig_ret_vaddr = orig_ret_vaddr;
@@ -1976,6 +2026,8 @@ static void prepare_uretprobe(struct upr
        ri->next = utask->return_instances;
        utask->return_instances = ri;
 
+       mod_timer(&utask->ri_timer, jiffies + HZ);
+
        return;
 
 err_mem:
@@ -2204,6 +2256,9 @@ handle_uretprobe_chain(struct return_ins
        struct uprobe *uprobe = ri->uprobe;
        struct uprobe_consumer *uc;
 
+       if (!uprobe)
+               return;
+
        guard(srcu)(&uprobes_srcu);
 
        for_each_consumer_rcu(uc, uprobe->consumers) {
@@ -2250,8 +2305,10 @@ static void handle_trampoline(struct pt_
 
                instruction_pointer_set(regs, ri->orig_ret_vaddr);
                do {
-                       if (valid)
+                       if (valid) {
                                handle_uretprobe_chain(ri, regs);
+                               mod_timer(&utask->ri_timer, jiffies + HZ);
+                       }
                        ri = free_ret_instance(ri);
                        utask->depth--;
                } while (ri != next);



Reply via email to