alloc_anon() calls malloc() without checking for a NULL return. If the
allocation fails, the loop that touches each page of the buffer
dereferences a NULL pointer.

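Factor the malloc()-and-touch sequence into a new helper,
alloc_and_populate_anon(), which prints an error and returns NULL when
malloc() fails, and use it in all four alloc_anon variants so that each
returns -1 on allocation failure:
- alloc_anon()
- alloc_anon_50M_check()
- alloc_anon_noexit()
- alloc_anon_50M_check_swap()

Only alloc_anon() lacked the check; the other three already handled a
NULL return and now share the helper instead of open-coding it.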

Signed-off-by: Hongfu Li <[email protected]>
Reviewed-by: Vishal Moola <[email protected]>
Reviewed-by: Muchun Song <[email protected]>
---
v2:
- Refactored the repeated malloc()-and-touch pattern in the alloc_anon*
  functions into a helper, alloc_and_populate_anon().
---
 .../selftests/cgroup/test_memcontrol.c        | 53 ++++++++++---------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/tools/testing/selftests/cgroup/test_memcontrol.c b/tools/testing/selftests/cgroup/test_memcontrol.c
index b43da9bc20c4..21aedb35cc12 100644
--- a/tools/testing/selftests/cgroup/test_memcontrol.c
+++ b/tools/testing/selftests/cgroup/test_memcontrol.c
@@ -55,15 +55,31 @@ int alloc_pagecache(int fd, size_t size)
        return -1;
 }
 
-int alloc_anon(const char *cgroup, void *arg)
+static char *alloc_and_populate_anon(size_t size)
 {
-       size_t size = (unsigned long)arg;
        char *buf, *ptr;
 
        buf = malloc(size);
+       if (buf == NULL) {
+               fprintf(stderr, "malloc() failed\n");
+               return NULL;
+       }
+
        for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
                *ptr = 0;
 
+       return buf;
+}
+
+int alloc_anon(const char *cgroup, void *arg)
+{
+       size_t size = (unsigned long)arg;
+       char *buf;
+
+       buf = alloc_and_populate_anon(size);
+       if (!buf)
+               return -1;
+
        free(buf);
        return 0;
 }
@@ -174,18 +190,13 @@ static int test_memcg_subtree_control(const char *root)
 static int alloc_anon_50M_check(const char *cgroup, void *arg)
 {
        size_t size = MB(50);
-       char *buf, *ptr;
+       char *buf;
        long anon, current;
        int ret = -1;
 
-       buf = malloc(size);
-       if (buf == NULL) {
-               fprintf(stderr, "malloc() failed\n");
+       buf = alloc_and_populate_anon(size);
+       if (!buf)
                return -1;
-       }
-
-       for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
-               *ptr = 0;
 
        current = cg_read_long(cgroup, "memory.current");
        if (current < size)
@@ -406,16 +417,11 @@ static int alloc_anon_noexit(const char *cgroup, void *arg)
 {
        int ppid = getppid();
        size_t size = (unsigned long)arg;
-       char *buf, *ptr;
+       char *buf;
 
-       buf = malloc(size);
-       if (buf == NULL) {
-               fprintf(stderr, "malloc() failed\n");
+       buf = alloc_and_populate_anon(size);
+       if (!buf)
                return -1;
-       }
-
-       for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
-               *ptr = 0;
 
        while (getppid() == ppid)
                sleep(1);
@@ -990,18 +996,13 @@ static int alloc_anon_50M_check_swap(const char *cgroup, void *arg)
 {
        long mem_max = (long)arg;
        size_t size = MB(50);
-       char *buf, *ptr;
+       char *buf;
        long mem_current, swap_current;
        int ret = -1;
 
-       buf = malloc(size);
-       if (buf == NULL) {
-               fprintf(stderr, "malloc() failed\n");
+       buf = alloc_and_populate_anon(size);
+       if (!buf)
                return -1;
-       }
-
-       for (ptr = buf; ptr < buf + size; ptr += PAGE_SIZE)
-               *ptr = 0;
 
        mem_current = cg_read_long(cgroup, "memory.current");
        if (!mem_current || !values_close(mem_current, mem_max, 3))
-- 
2.25.1


Reply via email to