On Mon, Jan 18, 2021 at 03:12:21PM +0100, Frederic Weisbecker wrote:
> +#ifdef CONFIG_PREEMPT_DYNAMIC
> +DEFINE_STATIC_CALL(preempt_schedule, __preempt_schedule_func());
> +EXPORT_STATIC_CALL(preempt_schedule);
> +#endif

> +#ifdef CONFIG_PREEMPT_DYNAMIC
> +DEFINE_STATIC_CALL(preempt_schedule_notrace, __preempt_schedule_notrace_func());
> +EXPORT_STATIC_CALL(preempt_schedule_notrace);
> +#endif

So one of the things I hate most about this is that it allows 'random'
modules to hijack the preemption by rewriting these callsites. Once you
export the key, we've lost.

I've tried a number of things, but this is the only one I could come up
with that actually stands a chance against malicious modules (vbox and
the like).

It's somewhat elaborate, but afaict it actually works.

---

--- a/arch/x86/include/asm/preempt.h
+++ b/arch/x86/include/asm/preempt.h
@@ -114,7 +114,7 @@ DECLARE_STATIC_CALL(preempt_schedule, __
 
 #define __preempt_schedule() \
 do { \
-       __ADDRESSABLE(STATIC_CALL_KEY(preempt_schedule)); \
+       __STATIC_CALL_MOD_ADDRESSABLE(preempt_schedule); \
        asm volatile ("call " STATIC_CALL_TRAMP_STR(preempt_schedule) : ASM_CALL_CONSTRAINT); \
 } while (0)
 
@@ -127,7 +127,7 @@ DECLARE_STATIC_CALL(preempt_schedule_not
 
 #define __preempt_schedule_notrace() \
 do { \
-       __ADDRESSABLE(STATIC_CALL_KEY(preempt_schedule_notrace)); \
+       __STATIC_CALL_MOD_ADDRESSABLE(preempt_schedule_notrace); \
        asm volatile ("call " STATIC_CALL_TRAMP_STR(preempt_schedule_notrace) : ASM_CALL_CONSTRAINT); \
 } while (0)
 
--- a/include/linux/kernel.h
+++ b/include/linux/kernel.h
@@ -93,7 +93,7 @@ DECLARE_STATIC_CALL(might_resched, __con
 
 static __always_inline void might_resched(void)
 {
-       static_call(might_resched)();
+       static_call_mod(might_resched)();
 }
 
 #else
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1880,7 +1880,7 @@ DECLARE_STATIC_CALL(cond_resched, __cond
 
 static __always_inline int _cond_resched(void)
 {
-       return static_call(cond_resched)();
+       return static_call_mod(cond_resched)();
 }
 
 #else
--- a/include/linux/static_call.h
+++ b/include/linux/static_call.h
@@ -107,6 +107,10 @@ extern void arch_static_call_transform(v
 
 #define STATIC_CALL_TRAMP_ADDR(name) &STATIC_CALL_TRAMP(name)
 
+#define static_call_register(name) \
+       __static_call_register(&STATIC_CALL_KEY(name), \
+                              &STATIC_CALL_TRAMP(name))  /* hands name's key and trampoline addresses to __static_call_register() */
+
 #else
 #define STATIC_CALL_TRAMP_ADDR(name) NULL
 #endif
@@ -138,6 +142,7 @@ struct static_call_key {
        };
 };
 
+extern int __static_call_register(struct static_call_key *key, void *tramp);
 extern void __static_call_update(struct static_call_key *key, void *tramp, void *func);
 extern int static_call_mod_init(struct module *mod);
 extern int static_call_text_reserved(void *start, void *end);
@@ -162,6 +167,9 @@ extern long __static_call_return0(void);
 
 #define static_call_cond(name) (void)__static_call(name)
 
+#define EXPORT_STATIC_CALL_TRAMP(name)                                 \
+       EXPORT_SYMBOL(STATIC_CALL_TRAMP(name))  /* exports the trampoline symbol only; unlike EXPORT_STATIC_CALL() the key is not exported */
+
 #define EXPORT_STATIC_CALL(name)                                       \
        EXPORT_SYMBOL(STATIC_CALL_KEY(name));                           \
        EXPORT_SYMBOL(STATIC_CALL_TRAMP(name))
@@ -194,6 +202,11 @@ struct static_call_key {
 
 #define static_call_cond(name) (void)__static_call(name)
 
+static inline int __static_call_register(struct static_call_key *key, void *tramp)
+{
+       return 0;  /* stub variant: records nothing and always reports success */
+}
+
 static inline
 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
 {
@@ -213,6 +226,9 @@ static inline long __static_call_return0
        return 0;
 }
 
+#define EXPORT_STATIC_CALL_TRAMP(name)                                 \
+       EXPORT_SYMBOL(STATIC_CALL_TRAMP(name))  /* exports the trampoline symbol only; unlike EXPORT_STATIC_CALL() the key is not exported */
+
 #define EXPORT_STATIC_CALL(name)                                       \
        EXPORT_SYMBOL(STATIC_CALL_KEY(name));                           \
        EXPORT_SYMBOL(STATIC_CALL_TRAMP(name))
--- a/include/linux/static_call_types.h
+++ b/include/linux/static_call_types.h
@@ -39,17 +39,39 @@ struct static_call_site {
 
 #ifdef CONFIG_HAVE_STATIC_CALL
 
+#define __raw_static_call(name)        (&STATIC_CALL_TRAMP(name))
+
+#ifdef CONFIG_HAVE_STATIC_CALL_INLINE
+
 /*
  * __ADDRESSABLE() is used to ensure the key symbol doesn't get stripped from
  * the symbol table so that objtool can reference it when it generates the
  * .static_call_sites section.
  */
+#define __STATIC_CALL_ADDRESSABLE(name) \
+       __ADDRESSABLE(STATIC_CALL_KEY(name))
+
 #define __static_call(name)                                            \
 ({                                                                     \
-       __ADDRESSABLE(STATIC_CALL_KEY(name));                           \
-       &STATIC_CALL_TRAMP(name);                                       \
+       __STATIC_CALL_ADDRESSABLE(name);                                \
+       __raw_static_call(name);                                        \
 })
 
+#else /* !CONFIG_HAVE_STATIC_CALL_INLINE */
+
+#define __STATIC_CALL_ADDRESSABLE(name)
+#define __static_call(name)    __raw_static_call(name)
+
+#endif /* CONFIG_HAVE_STATIC_CALL_INLINE */
+
+#ifdef MODULE
+#define __STATIC_CALL_MOD_ADDRESSABLE(name)
+#define static_call_mod(name)  __raw_static_call(name)
+#else
+#define __STATIC_CALL_MOD_ADDRESSABLE(name) __STATIC_CALL_ADDRESSABLE(name)
+#define static_call_mod(name)  __static_call(name)
+#endif
+
 #define static_call(name)      __static_call(name)
 
 #else
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -5268,7 +5268,7 @@ EXPORT_SYMBOL(preempt_schedule);
 
 #ifdef CONFIG_PREEMPT_DYNAMIC
 DEFINE_STATIC_CALL(preempt_schedule, __preempt_schedule_func());
-EXPORT_STATIC_CALL(preempt_schedule);
+EXPORT_STATIC_CALL_TRAMP(preempt_schedule);
 #endif
 
 
@@ -5326,7 +5326,7 @@ EXPORT_SYMBOL_GPL(preempt_schedule_notra
 
 #ifdef CONFIG_PREEMPT_DYNAMIC
 DEFINE_STATIC_CALL(preempt_schedule_notrace, __preempt_schedule_notrace_func());
-EXPORT_STATIC_CALL(preempt_schedule_notrace);
+EXPORT_STATIC_CALL_TRAMP(preempt_schedule_notrace);
 #endif
 
 #endif /* CONFIG_PREEMPTION */
@@ -6879,10 +6879,10 @@ EXPORT_SYMBOL(__cond_resched);
 
 #ifdef CONFIG_PREEMPT_DYNAMIC
 DEFINE_STATIC_CALL_RET0(cond_resched, __cond_resched);
-EXPORT_STATIC_CALL(cond_resched);
+EXPORT_STATIC_CALL_TRAMP(cond_resched);
 
 DEFINE_STATIC_CALL_RET0(might_resched, __cond_resched);
-EXPORT_STATIC_CALL(might_resched);
+EXPORT_STATIC_CALL_TRAMP(might_resched);
 #endif
 
 /*
@@ -8096,6 +8096,13 @@ void __init sched_init(void)
 
        init_uclamp();
 
+#ifdef CONFIG_PREEMPT_DYNAMIC
+       static_call_register(cond_resched);
+       static_call_register(might_resched);
+       static_call_register(preempt_schedule);
+       static_call_register(preempt_schedule_notrace);
+#endif
+
        scheduler_running = 1;
 }
 
--- a/kernel/static_call.c
+++ b/kernel/static_call.c
@@ -323,10 +323,85 @@ static int __static_call_mod_text_reserv
        return ret;
 }
 
+struct static_call_ass {  /* one trampoline -> key association, kept in the static_call_asses tree */
+       struct rb_node node;  /* tree linkage; __node_2_ass() maps it back to this struct */
+       struct static_call_key *key;  /* key stored alongside @tramp at registration */
+       unsigned long tramp;  /* trampoline address; the value the tree is ordered and searched by */
+};
+
+static struct rb_root static_call_asses;  /* root of the tree holding every registered struct static_call_ass */
+
+#define __node_2_ass(_n) \
+       rb_entry((_n), struct static_call_ass, node)  /* rb_node pointer -> the struct static_call_ass embedding it */
+
+static inline bool ass_less(struct rb_node *a, const struct rb_node *b)  /* ordering callback passed to rb_add() */
+{
+       return __node_2_ass(a)->tramp < __node_2_ass(b)->tramp;  /* true when a's trampoline address is numerically lower than b's */
+}
+
+static inline int ass_cmp(const void *a, const struct rb_node *b)  /* lookup callback passed to rb_find() */
+{
+       if (*(unsigned long *)a < __node_2_ass(b)->tramp)  /* @a points at the unsigned long address being searched for */
+               return -1;
+
+       if (*(unsigned long *)a > __node_2_ass(b)->tramp)
+               return 1;
+
+       return 0;  /* searched address equals this node's trampoline address */
+}
+
+int __static_call_register(struct static_call_key *key, void *tramp)  /* records @tramp -> @key; returns 0, or -ENOMEM if allocation fails */
+{
+       struct static_call_ass *ass = kzalloc(sizeof(*ass), GFP_KERNEL);
+       if (!ass)
+               return -ENOMEM;
+
+       ass->key = key;
+       ass->tramp = (unsigned long)tramp;
+
+       /* trampolines should be aligned */
+       WARN_ON_ONCE(ass->tramp & STATIC_CALL_SITE_FLAGS);  /* warns once if any STATIC_CALL_SITE_FLAGS bit is set in the address */
+
+       rb_add(&ass->node, &static_call_asses, &ass_less);  /* NOTE(review): no lock or duplicate check around the insert here - confirm callers serialize */
+       return 0;
+}
+
+static struct static_call_ass *static_call_find_ass(unsigned long addr)  /* returns the entry whose tramp equals @addr, or NULL */
+{
+       struct rb_node *node = rb_find(&addr, &static_call_asses, &ass_cmp);  /* ass_cmp() dereferences its key argument, hence &addr */
+       if (!node)
+               return NULL;  /* no association registered for this address */
+       return __node_2_ass(node);
+}
+
 static int static_call_add_module(struct module *mod)
 {
-       return __static_call_init(mod, mod->static_call_sites,
-                                 mod->static_call_sites + mod->num_static_call_sites);
+       struct static_call_site *start = mod->static_call_sites;  /* first call site recorded for this module */
+       struct static_call_site *stop = start + mod->num_static_call_sites;  /* one past the last site */
+       struct static_call_site *site;
+
+       for (site = start; site != stop; site++) {
+               unsigned long addr = (unsigned long)static_call_key(site);  /* what the site's key field currently resolves to */
+               struct static_call_ass *ass;
+
+               /*
+                * Gotta fix up the keys that point to the trampoline.
+                */
+               if (!kernel_text_address(addr))  /* only keys that land in kernel text are treated as trampoline references */
+                       continue;
+
+               ass = static_call_find_ass(addr);  /* trampoline must have been registered via __static_call_register() */
+               if (!ass) {
+                       pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
+                               static_call_addr(site));
+                       return -EINVAL;  /* unregistered trampoline: fail the whole module */
+               }
+
+               site->key = ((unsigned long)ass->key - (unsigned long)&site->key) |  /* store the real key as an offset from the field itself */
+                           (site->key & STATIC_CALL_SITE_FLAGS);  /* carry over the site's existing flag bits */
+       }
+
+       return __static_call_init(mod, start, stop);  /* keys now fixed up; run the regular init over the same range */
 }
 
 static void static_call_del_module(struct module *mod)
@@ -392,6 +467,11 @@ static struct notifier_block static_call
 
 #else
 
+int __static_call_register(struct static_call_key *key, void *tramp)  /* stub for the #else branch of this file: nothing is recorded */
+{
+       return 0;  /* always reports success */
+}
+
 static inline int __static_call_mod_text_reserved(void *start, void *end)
 {
        return 0;
--- a/tools/include/linux/static_call_types.h
+++ b/tools/include/linux/static_call_types.h
@@ -39,17 +39,39 @@ struct static_call_site {
 
 #ifdef CONFIG_HAVE_STATIC_CALL
 
+#define __raw_static_call(name)        (&STATIC_CALL_TRAMP(name))
+
+#ifdef CONFIG_HAVE_STATIC_CALL_INLINE
+
 /*
  * __ADDRESSABLE() is used to ensure the key symbol doesn't get stripped from
  * the symbol table so that objtool can reference it when it generates the
  * .static_call_sites section.
  */
+#define __STATIC_CALL_ADDRESSABLE(name) \
+       __ADDRESSABLE(STATIC_CALL_KEY(name))
+
 #define __static_call(name)                                            \
 ({                                                                     \
-       __ADDRESSABLE(STATIC_CALL_KEY(name));                           \
-       &STATIC_CALL_TRAMP(name);                                       \
+       __STATIC_CALL_ADDRESSABLE(name);                                \
+       __raw_static_call(name);                                        \
 })
 
+#else /* !CONFIG_HAVE_STATIC_CALL_INLINE */
+
+#define __STATIC_CALL_ADDRESSABLE(name)
+#define __static_call(name)    __raw_static_call(name)
+
+#endif /* CONFIG_HAVE_STATIC_CALL_INLINE */
+
+#ifdef MODULE
+#define __STATIC_CALL_MOD_ADDRESSABLE(name)
+#define static_call_mod(name)  __raw_static_call(name)
+#else
+#define __STATIC_CALL_MOD_ADDRESSABLE(name) __STATIC_CALL_ADDRESSABLE(name)
+#define static_call_mod(name)  __static_call(name)
+#endif
+
 #define static_call(name)      __static_call(name)
 
 #else
--- a/tools/objtool/check.c
+++ b/tools/objtool/check.c
@@ -502,8 +502,16 @@ static int create_static_call_sections(s
 
                key_sym = find_symbol_by_name(file->elf, tmp);
                if (!key_sym) {
-                       WARN("static_call: can't find static_call_key symbol: %s", tmp);
-                       return -1;
+                       if (!module) {
+                               WARN("static_call: can't find static_call_key symbol: %s", tmp);
+                               return -1;
+                       }
+                       /*
+                        * For static_call_mod() we allow the key to be the
+                        * trampoline address. This is fixed up in
+                        * static_call_add_module().
+                        */
+                       key_sym = insn->call_dest;
                }
                free(key_name);
 

Reply via email to