This patch ensures that we always check the netlink payload length
in audit_receive_msg() before we take any action on the payload
itself.

Cc: sta...@vger.kernel.org
Reported-by: syzbot+399c44bf1f43b8747...@syzkaller.appspotmail.com
Reported-by: syzbot+e4b12d8d202701f08...@syzkaller.appspotmail.com
Signed-off-by: Paul Moore <p...@paul-moore.com>
---
 kernel/audit.c |   43 +++++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/kernel/audit.c b/kernel/audit.c
index 17b0d523afb3..6e8b176bdb68 100644
--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -1101,13 +1101,11 @@ static void audit_log_feature_change(int which, u32 old_feature, u32 new_feature
        audit_log_end(ab);
 }
 
-static int audit_set_feature(struct sk_buff *skb)
+static int audit_set_feature(struct audit_features *uaf)
 {
-       struct audit_features *uaf;
        int i;
 
        BUILD_BUG_ON(AUDIT_LAST_FEATURE + 1 > ARRAY_SIZE(audit_feature_names));
-       uaf = nlmsg_data(nlmsg_hdr(skb));
 
        /* if there is ever a version 2 we should handle that here */
 
@@ -1175,6 +1173,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
        u32                     seq;
        void                    *data;
+       int                     data_len;
        int                     err;
        struct audit_buffer     *ab;
        u16                     msg_type = nlh->nlmsg_type;
@@ -1188,6 +1187,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 
        seq  = nlh->nlmsg_seq;
        data = nlmsg_data(nlh);
+       data_len = nlmsg_len(nlh);
 
        switch (msg_type) {
        case AUDIT_GET: {
@@ -1211,7 +1211,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                struct audit_status     s;
                memset(&s, 0, sizeof(s));
                /* guard against past and future API changes */
-               memcpy(&s, data, min_t(size_t, sizeof(s), nlmsg_len(nlh)));
+               memcpy(&s, data, min_t(size_t, sizeof(s), data_len));
                if (s.mask & AUDIT_STATUS_ENABLED) {
                        err = audit_set_enabled(s.enabled);
                        if (err < 0)
@@ -1314,11 +1314,14 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                if (err)
                        return err;
                break;
-       case AUDIT_SET_FEATURE:
-               err = audit_set_feature(skb);
+       case AUDIT_SET_FEATURE: {
+               if (data_len < sizeof(struct audit_features))
+                       return -EINVAL;
+               err = audit_set_feature(data);
                if (err)
                        return err;
                break;
+       }
        case AUDIT_USER:
        case AUDIT_FIRST_USER_MSG ... AUDIT_LAST_USER_MSG:
        case AUDIT_FIRST_USER_MSG2 ... AUDIT_LAST_USER_MSG2:
@@ -1327,6 +1330,8 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 
                err = audit_filter(msg_type, AUDIT_FILTER_USER);
                if (err == 1) { /* match or error */
+                       char *str = data;
+
                        err = 0;
                        if (msg_type == AUDIT_USER_TTY) {
                                err = tty_audit_push();
@@ -1334,26 +1339,24 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                        break;
                        }
                        audit_log_user_recv_msg(&ab, msg_type);
-                       if (msg_type != AUDIT_USER_TTY)
+                       if (msg_type != AUDIT_USER_TTY) {
+                               /* ensure NULL termination */
+                               str[data_len - 1] = '\0';
                                audit_log_format(ab, " msg='%.*s'",
                                                 AUDIT_MESSAGE_TEXT_MAX,
-                                                (char *)data);
-                       else {
-                               int size;
-
+                                                str);
+                       } else {
                                audit_log_format(ab, " data=");
-                               size = nlmsg_len(nlh);
-                               if (size > 0 &&
-                                   ((unsigned char *)data)[size - 1] == '\0')
-                                       size--;
-                               audit_log_n_untrustedstring(ab, data, size);
+                               if (data_len > 0 && str[data_len - 1] == '\0')
+                                       data_len--;
+                               audit_log_n_untrustedstring(ab, data, data_len);
                        }
                        audit_log_end(ab);
                }
                break;
        case AUDIT_ADD_RULE:
        case AUDIT_DEL_RULE:
-               if (nlmsg_len(nlh) < sizeof(struct audit_rule_data))
+               if (data_len < sizeof(struct audit_rule_data))
                        return -EINVAL;
                if (audit_enabled == AUDIT_LOCKED) {
                        audit_log_common_recv_msg(audit_context(), &ab,
@@ -1365,7 +1368,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                        audit_log_end(ab);
                        return -EPERM;
                }
-               err = audit_rule_change(msg_type, seq, data, nlmsg_len(nlh));
+               err = audit_rule_change(msg_type, seq, data, data_len);
                break;
        case AUDIT_LIST_RULES:
                err = audit_list_rules_send(skb, seq);
@@ -1380,7 +1383,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
        case AUDIT_MAKE_EQUIV: {
                void *bufp = data;
                u32 sizes[2];
-               size_t msglen = nlmsg_len(nlh);
+               size_t msglen = data_len;
                char *old, *new;
 
                err = -EINVAL;
@@ -1456,7 +1459,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 
                memset(&s, 0, sizeof(s));
                /* guard against past and future API changes */
-               memcpy(&s, data, min_t(size_t, sizeof(s), nlmsg_len(nlh)));
+               memcpy(&s, data, min_t(size_t, sizeof(s), data_len));
                /* check if new data is valid */
                if ((s.enabled != 0 && s.enabled != 1) ||
                    (s.log_passwd != 0 && s.log_passwd != 1))

--
Linux-audit mailing list
Linux-audit@redhat.com
https://www.redhat.com/mailman/listinfo/linux-audit

Reply via email to