From: Alexander Potapenko <gli...@google.com>

Inserts KFENCE hooks into the SLAB allocator.

An 'orig_size' argument is added to the slab_alloc*() functions so that
the originally requested size can be passed to KFENCE.
When KFENCE is disabled, there is no additional overhead, since these
functions are __always_inline.
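
For reviewers, the hook pattern is simply "try KFENCE first, fall back to
the regular allocation path", mirrored on the free side. The stand-alone
userspace sketch below is illustrative only: kfence_alloc()/kfence_free()
are stub placeholders with reduced signatures (the real helpers are added
by the KFENCE core patch and return NULL/false when KFENCE is disabled, so
the fast path is unchanged):

  /* Illustrative sketch of the hook pattern; not kernel code. */
  #include <stdbool.h>
  #include <stdio.h>
  #include <stdlib.h>

  static inline void *kfence_alloc(size_t orig_size) { return NULL; } /* stub */
  static inline bool kfence_free(void *obj) { return false; }         /* stub */

  static inline void *slab_alloc(size_t orig_size)
  {
          void *obj = kfence_alloc(orig_size);    /* try KFENCE first */

          if (obj)
                  return obj;                     /* KFENCE-guarded object */
          return malloc(orig_size);               /* regular allocator path */
  }

  static inline void slab_free(void *obj)
  {
          if (kfence_free(obj))                   /* object owned by KFENCE */
                  return;
          free(obj);                              /* regular allocator path */
  }

  int main(void)
  {
          void *p = slab_alloc(32);

          printf("allocated %p\n", p);
          slab_free(p);
          return 0;
  }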

Co-developed-by: Marco Elver <el...@google.com>
Signed-off-by: Marco Elver <el...@google.com>
Signed-off-by: Alexander Potapenko <gli...@google.com>
---
 mm/slab.c        | 46 ++++++++++++++++++++++++++++++++++------------
 mm/slab_common.c |  6 +++++-
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/mm/slab.c b/mm/slab.c
index 3160dff6fd76..30aba06ae02b 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -100,6 +100,7 @@
 #include       <linux/seq_file.h>
 #include       <linux/notifier.h>
 #include       <linux/kallsyms.h>
+#include       <linux/kfence.h>
 #include       <linux/cpu.h>
 #include       <linux/sysctl.h>
 #include       <linux/module.h>
@@ -3206,7 +3207,7 @@ static void *____cache_alloc_node(struct kmem_cache *cachep, gfp_t flags,
 }
 
 static __always_inline void *
-slab_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid,
+slab_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid, size_t orig_size,
                   unsigned long caller)
 {
        unsigned long save_flags;
@@ -3219,6 +3220,10 @@ slab_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid,
        if (unlikely(!cachep))
                return NULL;
 
+       ptr = kfence_alloc(cachep, orig_size, flags);
+       if (unlikely(ptr))
+               goto out_hooks;
+
        cache_alloc_debugcheck_before(cachep, flags);
        local_irq_save(save_flags);
 
@@ -3251,6 +3256,7 @@ slab_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid,
        if (unlikely(slab_want_init_on_alloc(flags, cachep)) && ptr)
                memset(ptr, 0, cachep->object_size);
 
+out_hooks:
        slab_post_alloc_hook(cachep, objcg, flags, 1, &ptr);
        return ptr;
 }
@@ -3288,7 +3294,7 @@ __do_cache_alloc(struct kmem_cache *cachep, gfp_t flags)
 #endif /* CONFIG_NUMA */
 
 static __always_inline void *
-slab_alloc(struct kmem_cache *cachep, gfp_t flags, unsigned long caller)
+slab_alloc(struct kmem_cache *cachep, gfp_t flags, size_t orig_size, unsigned long caller)
 {
        unsigned long save_flags;
        void *objp;
@@ -3299,6 +3305,10 @@ slab_alloc(struct kmem_cache *cachep, gfp_t flags, unsigned long caller)
        if (unlikely(!cachep))
                return NULL;
 
+       objp = kfence_alloc(cachep, orig_size, flags);
+       if (unlikely(objp))
+               goto out_hooks;
+
        cache_alloc_debugcheck_before(cachep, flags);
        local_irq_save(save_flags);
        objp = __do_cache_alloc(cachep, flags);
@@ -3309,6 +3319,7 @@ slab_alloc(struct kmem_cache *cachep, gfp_t flags, unsigned long caller)
        if (unlikely(slab_want_init_on_alloc(flags, cachep)) && objp)
                memset(objp, 0, cachep->object_size);
 
+out_hooks:
        slab_post_alloc_hook(cachep, objcg, flags, 1, &objp);
        return objp;
 }
@@ -3414,6 +3425,11 @@ static void cache_flusharray(struct kmem_cache *cachep, struct array_cache *ac)
 static __always_inline void __cache_free(struct kmem_cache *cachep, void *objp,
                                         unsigned long caller)
 {
+       if (kfence_free(objp)) {
+               kmemleak_free_recursive(objp, cachep->flags);
+               return;
+       }
+
        /* Put the object into the quarantine, don't touch it for now. */
        if (kasan_slab_free(cachep, objp, _RET_IP_))
                return;
@@ -3479,7 +3495,7 @@ void ___cache_free(struct kmem_cache *cachep, void *objp,
  */
 void *kmem_cache_alloc(struct kmem_cache *cachep, gfp_t flags)
 {
-       void *ret = slab_alloc(cachep, flags, _RET_IP_);
+       void *ret = slab_alloc(cachep, flags, cachep->object_size, _RET_IP_);
 
        trace_kmem_cache_alloc(_RET_IP_, ret,
                               cachep->object_size, cachep->size, flags);
@@ -3512,7 +3528,7 @@ int kmem_cache_alloc_bulk(struct kmem_cache *s, gfp_t flags, size_t size,
 
        local_irq_disable();
        for (i = 0; i < size; i++) {
-               void *objp = __do_cache_alloc(s, flags);
+               void *objp = kfence_alloc(s, s->object_size, flags) ?: __do_cache_alloc(s, flags);
 
                if (unlikely(!objp))
                        goto error;
@@ -3545,7 +3561,7 @@ kmem_cache_alloc_trace(struct kmem_cache *cachep, gfp_t flags, size_t size)
 {
        void *ret;
 
-       ret = slab_alloc(cachep, flags, _RET_IP_);
+       ret = slab_alloc(cachep, flags, size, _RET_IP_);
 
        ret = kasan_kmalloc(cachep, ret, size, flags);
        trace_kmalloc(_RET_IP_, ret,
@@ -3571,7 +3587,7 @@ EXPORT_SYMBOL(kmem_cache_alloc_trace);
  */
 void *kmem_cache_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid)
 {
-       void *ret = slab_alloc_node(cachep, flags, nodeid, _RET_IP_);
+       void *ret = slab_alloc_node(cachep, flags, nodeid, cachep->object_size, _RET_IP_);
 
        trace_kmem_cache_alloc_node(_RET_IP_, ret,
                                    cachep->object_size, cachep->size,
@@ -3589,7 +3605,7 @@ void *kmem_cache_alloc_node_trace(struct kmem_cache *cachep,
 {
        void *ret;
 
-       ret = slab_alloc_node(cachep, flags, nodeid, _RET_IP_);
+       ret = slab_alloc_node(cachep, flags, nodeid, size, _RET_IP_);
 
        ret = kasan_kmalloc(cachep, ret, size, flags);
        trace_kmalloc_node(_RET_IP_, ret,
@@ -3650,7 +3666,7 @@ static __always_inline void *__do_kmalloc(size_t size, gfp_t flags,
        cachep = kmalloc_slab(size, flags);
        if (unlikely(ZERO_OR_NULL_PTR(cachep)))
                return cachep;
-       ret = slab_alloc(cachep, flags, caller);
+       ret = slab_alloc(cachep, flags, size, caller);
 
        ret = kasan_kmalloc(cachep, ret, size, flags);
        trace_kmalloc(caller, ret,
@@ -4138,18 +4154,24 @@ void __check_heap_object(const void *ptr, unsigned long n, struct page *page,
                         bool to_user)
 {
        struct kmem_cache *cachep;
-       unsigned int objnr;
+       unsigned int objnr = 0;
        unsigned long offset;
+       bool is_kfence = is_kfence_address(ptr);
 
        ptr = kasan_reset_tag(ptr);
 
        /* Find and validate object. */
        cachep = page->slab_cache;
-       objnr = obj_to_index(cachep, page, (void *)ptr);
-       BUG_ON(objnr >= cachep->num);
+       if (!is_kfence) {
+               objnr = obj_to_index(cachep, page, (void *)ptr);
+               BUG_ON(objnr >= cachep->num);
+       }
 
        /* Find offset within object. */
-       offset = ptr - index_to_obj(cachep, page, objnr) - obj_offset(cachep);
+       if (is_kfence)
+               offset = ptr - kfence_object_start(ptr);
+       else
+               offset = ptr - index_to_obj(cachep, page, objnr) - obj_offset(cachep);
 
        /* Allow address range falling entirely within usercopy region. */
        if (offset >= cachep->useroffset &&
diff --git a/mm/slab_common.c b/mm/slab_common.c
index f9ccd5dc13f3..6e35e273681a 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -12,6 +12,7 @@
 #include <linux/memory.h>
 #include <linux/cache.h>
 #include <linux/compiler.h>
+#include <linux/kfence.h>
 #include <linux/module.h>
 #include <linux/cpu.h>
 #include <linux/uaccess.h>
@@ -448,6 +449,9 @@ static int shutdown_cache(struct kmem_cache *s)
        /* free asan quarantined objects */
        kasan_cache_shutdown(s);
 
+       if (!kfence_shutdown_cache(s))
+               return -EBUSY;
+
        if (__kmem_cache_shutdown(s) != 0)
                return -EBUSY;
 
@@ -1171,7 +1175,7 @@ size_t ksize(const void *objp)
        if (unlikely(ZERO_OR_NULL_PTR(objp)) || !__kasan_check_read(objp, 1))
                return 0;
 
-       size = __ksize(objp);
+       size = kfence_ksize(objp) ?: __ksize(objp);
        /*
         * We assume that ksize callers could use whole allocated area,
         * so we need to unpoison this area.
-- 
2.28.0.526.ge36021eeef-goog
