For the save direction, this helper listens on a unix socket
which QEMU connects to for multifd migration to a file.

For the restore direction, this helper connects to a unix socket
on which QEMU listens for multifd migration from a file.

The main file descriptor is passed as a command line parameter,
and interleaved channels are used to allow reading/writing
different parts of the file depending on the channel used.

Signed-off-by: Claudio Fontana <cfont...@suse.de>
---
 po/POTFILES               |   1 +
 src/util/meson.build      |  16 ++
 src/util/multifd-helper.c | 359 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 376 insertions(+)
 create mode 100644 src/util/multifd-helper.c

diff --git a/po/POTFILES b/po/POTFILES
index faaba53c8f..97ecbb0ead 100644
--- a/po/POTFILES
+++ b/po/POTFILES
@@ -241,6 +241,7 @@ src/storage_file/storage_source.c
 src/storage_file/storage_source_backingstore.c
 src/test/test_driver.c
 src/util/iohelper.c
+src/util/multifd-helper.c
 src/util/viralloc.c
 src/util/virarptable.c
 src/util/viraudit.c
diff --git a/src/util/meson.build b/src/util/meson.build
index 07ae94631c..2e08ed8745 100644
--- a/src/util/meson.build
+++ b/src/util/meson.build
@@ -179,6 +179,11 @@ io_helper_sources = [
   'virfile.c',
 ]
 
+# sources of libvirt_multifd_helper (QEMU multifd migration to/from a file)
+multifd_helper_sources = [
+  'multifd-helper.c',
+  'virfile.c',
+]
+
 virt_util_lib = static_library(
   'virt_util',
   [
@@ -220,6 +225,17 @@ if conf.has('WITH_LIBVIRTD')
       libutil_dep,
     ],
   }
+  # libvirt_multifd_helper, built from multifd_helper_sources
+  virt_helpers += {
+    'name': 'libvirt_multifd_helper',
+    'sources': [
+      files(multifd_helper_sources),
+      dtrace_gen_headers,
+    ],
+    'deps': [
+      acl_dep,
+      libutil_dep,
+    ],
+  }
 endif
 
 util_inc_dir = include_directories('.')
diff --git a/src/util/multifd-helper.c b/src/util/multifd-helper.c
new file mode 100644
index 0000000000..6d26e2210a
--- /dev/null
+++ b/src/util/multifd-helper.c
@@ -0,0 +1,359 @@
+/*
+ * multifd-helper.c: multifd channel I/O between a Unix socket and one file
+ *
+ * Copyright (C) 2022 SUSE LLC
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library.  If not, see
+ * <http://www.gnu.org/licenses/>.
+ *
+ * This has been written to support QEMU multifd migration to file,
+ * allowing better use of cpu resources to speed up the save/restore.
+ */
+
+#include <config.h>
+
+#include <unistd.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "virfile.h"
+#include "virerror.h"
+#include "virstring.h"
+#include "virgettext.h"
+
+#define VIR_FROM_THIS VIR_FROM_STORAGE
+
+/* State of one channel; conn[0] is the main channel (MAINFD). */
+typedef struct _multiFdConnData multiFdConnData;
+struct _multiFdConnData {
+    int idx;                /* channel index, 0 for the main channel */
+    int nchannels;          /* number of channels, main channel included */
+    int clientfd;           /* unix socket to QEMU for this channel */
+    int filefd;             /* this channel's descriptor on the disk file */
+    int oflags;             /* open flags of filefd */
+    const char *sun_path;   /* unix socket path */
+    const char *disk_path;  /* disk pathname */
+    off_t total;            /* channel length: loaded from the CLIA when
+                             * restoring, stored into the CLIA when saving */
+
+    GThread *tid;           /* thread running clientThreadFunc() */
+};
+
+/* Arguments of the load/save server thread. */
+typedef struct _multiFdThreadArgs multiFdThreadArgs;
+struct _multiFdThreadArgs {
+    int nchannels;             /* number of channels, main channel included */
+    multiFdConnData *conn;     /* nchannels entries, conn[0] = main channel */
+    const char *sun_path;      /* unix socket name to use for the server */
+    const char *disk_path;     /* disk pathname */
+    struct sockaddr_un serv_addr;  /* socket address built from sun_path */
+
+    off_t total;               /* sum of the channel totals, -1 on failure */
+};
+
+/*
+ * clientThreadFunc:
+ * @a: the multiFdConnData of one channel
+ *
+ * Thread body for a single channel: runs virFileDiskCopyChannel() with the
+ * channel's file fd, socket fd, index and incoming total, stores the result
+ * in the channel's total and returns a pointer to it (waitClientThreads()
+ * treats a negative value as failure).
+ */
+static gpointer clientThreadFunc(void *a)
+{
+    multiFdConnData *conn = a;
+    off_t copied;
+
+    copied = virFileDiskCopyChannel(conn->filefd, conn->disk_path,
+                                    conn->clientfd, conn->sun_path,
+                                    conn->idx, conn->nchannels, conn->total);
+    conn->total = copied;
+    return &conn->total;
+}
+
+/*
+ * waitClientThreads:
+ * @conn: array of @n channels whose threads have all been started
+ * @n: number of channels
+ *
+ * Join every channel thread, in index order, and close its client socket.
+ * Returns the sum of the channel totals, or -1 if any thread reported a
+ * negative total or any client socket failed to close.
+ */
+static off_t waitClientThreads(multiFdConnData *conn, int n)
+{
+    off_t sum = 0;
+    int i = 0;
+
+    while (i < n) {
+        multiFdConnData *chan = conn + i;
+        off_t *chanTotal = g_thread_join(chan->tid);
+
+        if (*chanTotal < (off_t)0)
+            sum = -1;
+        else if (sum >= 0)
+            sum += *chanTotal;
+
+        if (VIR_CLOSE(chan->clientfd) < 0)
+            sum = -1;
+        i++;
+    }
+    return sum;
+}
+
+/*
+ * loadThreadFunc:
+ * @a: multiFdThreadArgs describing the channels
+ *
+ * Restore direction: open one connection per channel to the unix socket at
+ * args->serv_addr, start clientThreadFunc() on each, then wait for all of
+ * them. Sets and returns &args->total: the sum of all channel totals, or -1
+ * on failure.
+ *
+ * On failure, threads that were already started are neither joined nor
+ * interrupted and keep owning their sockets (main() exits when the total
+ * is negative). Channels that were never set up are left untouched, so no
+ * descriptor that was not opened here is ever closed.
+ */
+static gpointer loadThreadFunc(void *a)
+{
+    multiFdThreadArgs *args = a;
+    int idx;
+
+    args->total = -1;
+
+    for (idx = 0; idx < args->nchannels; idx++) {
+        /* Perform outgoing connections */
+        multiFdConnData *c = &args->conn[idx];
+
+        c->clientfd = socket(AF_UNIX, SOCK_STREAM, 0);
+        if (c->clientfd < 0) {
+            virReportSystemError(errno, "%s", _("loadThread: socket() failed"));
+            return &args->total;
+        }
+        if (connect(c->clientfd, (const struct sockaddr *)&args->serv_addr,
+                    sizeof(struct sockaddr_un)) < 0) {
+            virReportSystemError(errno, "%s", _("loadThread: connect() failed"));
+            /* not handed to a thread yet: this socket is still ours */
+            VIR_FORCE_CLOSE(c->clientfd);
+            return &args->total;
+        }
+        c->tid = g_thread_new("libvirt_multifd_load", &clientThreadFunc, c);
+    }
+    /* joins every thread and closes every client socket */
+    args->total = waitClientThreads(args->conn, args->nchannels);
+    return &args->total;
+}
+
+/*
+ * saveThreadFunc:
+ * @a: multiFdThreadArgs describing the channels
+ *
+ * Save direction: listen on args->sun_path, write a single 'R' byte to
+ * stdout once the socket is ready, then accept one connection per channel
+ * and start clientThreadFunc() on each. Sets and returns &args->total: the
+ * sum of all channel totals, or -1 on failure.
+ *
+ * On failure, threads that were already started are neither joined nor
+ * interrupted and keep owning their sockets (main() exits when the total
+ * is negative). Channels that were never accepted are left untouched, so
+ * no descriptor that was not opened here is ever closed.
+ */
+static gpointer saveThreadFunc(void *a)
+{
+    multiFdThreadArgs *args = a;
+    int idx;
+    const char buf[1] = {'R'};
+    int sockfd;
+
+    args->total = -1;
+
+    if ((sockfd = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
+        virReportSystemError(errno, "%s", _("saveThread: socket() failed"));
+        return &args->total;
+    }
+    /* remove a stale socket file, if any; bind() reports real problems */
+    unlink(args->sun_path);
+    if (bind(sockfd, (struct sockaddr *)&args->serv_addr,
+             sizeof(args->serv_addr)) < 0) {
+        virReportSystemError(errno, "%s", _("saveThread: bind() failed"));
+        goto cleanup;
+    }
+    if (listen(sockfd, args->nchannels) < 0) {
+        virReportSystemError(errno, "%s", _("saveThread: listen() failed"));
+        goto cleanup;
+    }
+    /* signal that the server is ready */
+    if (safewrite(STDOUT_FILENO, &buf, 1) != 1) {
+        virReportSystemError(errno, "%s", _("saveThread: safewrite failed"));
+        goto cleanup;
+    }
+    for (idx = 0; idx < args->nchannels; idx++) {
+        /* Wait for incoming connection. */
+        multiFdConnData *c = &args->conn[idx];
+
+        if ((c->clientfd = accept(sockfd, NULL, NULL)) < 0) {
+            virReportSystemError(errno, "%s", _("saveThread: accept() failed"));
+            goto cleanup;
+        }
+        c->tid = g_thread_new("libvirt_multifd_save", &clientThreadFunc, c);
+    }
+    /* joins every thread and closes every client socket */
+    args->total = waitClientThreads(args->conn, args->nchannels);
+
+ cleanup:
+    if (VIR_CLOSE(sockfd) < 0)
+        args->total = -1;
+    return &args->total;
+}
+
+/*
+ * readCLIA:
+ * @disk_fd: file descriptor positioned at the start of the Channel Length
+ *           Indicators Area (CLIA)
+ * @nchannels: number of channels, main channel included
+ * @conn: array of @nchannels channels
+ *
+ * Read the CLIA (one int64_t per channel, stored as raw host-order values
+ * and padded to virFileDirectAlign() size) and store entry idx into
+ * conn[idx].total.
+ *
+ * Returns 0 on success, -1 on read error or short read.
+ */
+static int readCLIA(int disk_fd, int nchannels, multiFdConnData *conn)
+{
+    int idx;
+    g_autofree void *base = NULL; /* Location to be freed */
+    size_t buflen = virFileDirectAlign(nchannels * 8);
+    int64_t *buf = virFileDirectBufferNew(&base, buflen);
+    ssize_t got = saferead(disk_fd, buf, buflen);
+
+    /* test got < 0 first: in a plain "got < buflen" a -1 would be converted
+     * to SIZE_MAX and a read error would be taken for success */
+    if (got < 0 || (size_t)got < buflen)
+        return -1;
+
+    for (idx = 0; idx < nchannels; idx++) {
+        multiFdConnData *c = &conn[idx];
+        c->total = buf[idx];
+    }
+    return 0;
+}
+
+/*
+ * writeCLIA:
+ * @disk_fd: file descriptor positioned at the start of the Channel Length
+ *           Indicators Area (CLIA)
+ * @nchannels: number of channels, main channel included
+ * @conn: array of @nchannels channels
+ *
+ * Write conn[idx].total for every channel as one int64_t each (raw host
+ * order), padded with zeroes to virFileDirectAlign() size.
+ *
+ * Returns 0 on success, -1 on write error or short write.
+ */
+static int writeCLIA(int disk_fd, int nchannels, multiFdConnData *conn)
+{
+    int idx;
+    g_autofree void *base = NULL; /* Location to be freed */
+    size_t buflen = virFileDirectAlign(nchannels * 8);
+    int64_t *buf = virFileDirectBufferNew(&base, buflen);
+    ssize_t written;
+
+    /* zero the alignment padding, in case virFileDirectBufferNew() does not,
+     * so that no stale heap bytes end up on disk */
+    memset(buf, 0, buflen);
+    for (idx = 0; idx < nchannels; idx++) {
+        multiFdConnData *c = &conn[idx];
+        buf[idx] = c->total;
+    }
+    written = safewrite(disk_fd, buf, buflen);
+    /* test written < 0 first: in a plain "written < buflen" a -1 would be
+     * converted to SIZE_MAX and a write error would be taken for success */
+    if (written < 0 || (size_t)written < buflen)
+        return -1;
+    return 0;
+}
+
+static const char *program_name;
+
+/*
+ * usage:
+ * @status: exit status; 0 prints the full usage line to stdout, anything
+ *          else prints a short hint to stderr
+ *
+ * Print usage information and exit the process with @status.
+ */
+G_GNUC_NORETURN static void
+usage(int status)
+{
+    if (status) {
+        fprintf(stderr, _("%s: try --help for more details\n"), program_name);
+    } else {
+        printf(_("Usage: %s UNIX_SOCNAME DISK_PATHNAME N MAINFD\n"),
+               program_name);
+    }
+    exit(status);
+}
+
+/*
+ * Usage: libvirt_multifd_helper UNIX_SOCNAME DISK_PATHNAME N MAINFD
+ *
+ * MAINFD is an open descriptor on DISK_PATHNAME whose current offset is
+ * where the Channel Length Indicators Area (CLIA) starts; N is the number
+ * of channels in addition to the main one. If MAINFD is open read-only the
+ * helper restores (connects to UNIX_SOCNAME, lengths read from the CLIA);
+ * otherwise it saves (listens on UNIX_SOCNAME) and writes the CLIA once all
+ * channels are done.
+ */
+int
+main(int argc, char **argv)
+{
+    GThread *tid;
+    GThreadFunc func;
+    multiFdThreadArgs args = { 0 };
+    multiFdConnData *mc;
+    int idx;
+    off_t clia_off, data_off, *total;
+
+    program_name = argv[0];
+    if (virGettextInitialize() < 0 ||
+        virErrorInitialize() < 0) {
+        fprintf(stderr, _("%s: initialization failed\n"), program_name);
+        exit(EXIT_FAILURE);
+    }
+
+    if (argc > 1 && STREQ(argv[1], "--help"))
+        usage(EXIT_SUCCESS);
+    if (argc < 5)
+        usage(EXIT_FAILURE);
+
+    args.sun_path = argv[1];
+    args.disk_path = argv[2];
+    /* N must be non-negative and small enough that N + 1 and the CLIA size
+     * (8 bytes per channel) cannot overflow an int */
+    if (virStrToLong_i(argv[3], NULL, 10, &args.nchannels) < 0 ||
+        args.nchannels < 0 || args.nchannels >= G_MAXINT / 8) {
+        fprintf(stderr, _("%s: malformed number of channels N %s\n"),
+                program_name, argv[3]);
+        usage(EXIT_FAILURE);
+    }
+    /* consider the main channel as just another channel */
+    args.nchannels += 1;
+    args.conn = g_new0(multiFdConnData, args.nchannels);
+    /* g_new0() leaves every clientfd at 0 (stdin): mark all channels as not
+     * connected so that no cleanup path can close an unrelated descriptor */
+    for (idx = 0; idx < args.nchannels; idx++)
+        args.conn[idx].clientfd = -1;
+
+    /* set main channel connection data */
+    mc = &args.conn[0];
+    mc->idx = 0;
+    mc->nchannels = args.nchannels;
+    if (virStrToLong_i(argv[4], NULL, 10, &mc->filefd) < 0) {
+        fprintf(stderr, _("%s: malformed MAINFD %s\n"), program_name, argv[4]);
+        usage(EXIT_FAILURE);
+    }
+
+#ifndef F_GETFL
+# error "multifd-helper requires F_GETFL parameter of fcntl"
+#endif
+
+    mc->oflags = fcntl(mc->filefd, F_GETFL);
+    if (mc->oflags < 0) {
+        fprintf(stderr, _("%s: invalid MAINFD %s\n"), program_name, argv[4]);
+        exit(EXIT_FAILURE);
+    }
+    mc->sun_path = args.sun_path;
+    mc->disk_path = args.disk_path;
+    /* the CLIA starts at the current offset of MAINFD */
+    clia_off = lseek(mc->filefd, 0, SEEK_CUR);
+    if (clia_off < 0) {
+        fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                args.disk_path);
+        exit(EXIT_FAILURE);
+    }
+    if ((mc->oflags & O_ACCMODE) == O_RDONLY) {
+        func = loadThreadFunc;
+        /* set totals from the Channel Length Indicators Area */
+        if (readCLIA(mc->filefd, args.nchannels, args.conn) < 0) {
+            fprintf(stderr, _("%s: failed to read CLIA\n"), program_name);
+            exit(EXIT_FAILURE);
+        }
+    } else {
+        func = saveThreadFunc;
+        /* skip Channel Length Indicators Area, it is written at the end */
+        if (lseek(mc->filefd, virFileDirectAlign(args.nchannels * 8),
+                  SEEK_CUR) < 0) {
+            fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+        mc->total = 0;
+    }
+    /* channel data starts right after the CLIA */
+    if ((data_off = lseek(mc->filefd, 0, SEEK_CUR)) < 0) {
+        fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                args.disk_path);
+        exit(EXIT_FAILURE);
+    }
+
+    /* initialize the additional channels, each with its own fd on the file */
+    for (idx = 1; idx < args.nchannels; idx++) {
+        multiFdConnData *c = &args.conn[idx];
+        c->idx = idx;
+        c->nchannels = args.nchannels;
+        c->oflags = mc->oflags & ~(O_TRUNC | O_CREAT);
+        c->filefd = open(args.disk_path, c->oflags);
+        if (c->filefd < 0) {
+            fprintf(stderr, _("%s: failed to open %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+        c->sun_path = args.sun_path;
+        c->disk_path = args.disk_path;
+        /* c->total is already right: read from the CLIA when restoring,
+         * zero-initialized by g_new0() when saving; it must not be reset
+         * here based on the main channel's total */
+        if (lseek(c->filefd, data_off, SEEK_SET) < 0) {
+            fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+    }
+
+    /* initialize server address structure */
+    memset(&args.serv_addr, 0, sizeof(args.serv_addr));
+    args.serv_addr.sun_family = AF_UNIX;
+    /* refuse to silently bind/connect to a truncated socket path */
+    if (virStrcpyStatic(args.serv_addr.sun_path, args.sun_path) < 0) {
+        fprintf(stderr, _("%s: socket path too long %s\n"), program_name,
+                args.sun_path);
+        exit(EXIT_FAILURE);
+    }
+
+    tid = g_thread_new("libvirt_multifd_func", func, &args);
+
+    total = g_thread_join(tid);
+    if (*total < 0) {
+        exit(EXIT_FAILURE);
+    }
+    if (func == saveThreadFunc) {
+        /* write CLIA */
+        if (lseek(mc->filefd, clia_off, SEEK_SET) < 0) {
+            fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+        /* set totals into the Channel Length Indicators Area */
+        if (writeCLIA(mc->filefd, args.nchannels, args.conn) < 0) {
+            fprintf(stderr, _("%s: failed to write CLIA\n"), program_name);
+            exit(EXIT_FAILURE);
+        }
+        if (lseek(mc->filefd, 0, SEEK_END) < 0) {
+            fprintf(stderr, _("%s: failed to seek %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+        /* EINVAL/EROFS are tolerated: the fd may not support syncing */
+        if (virFileDataSync(mc->filefd) < 0) {
+            if (errno != EINVAL && errno != EROFS) {
+                fprintf(stderr, _("%s: failed to fsyncdata %s\n"),
+                        program_name, args.disk_path);
+                exit(EXIT_FAILURE);
+            }
+        }
+    }
+    /* close up */
+    for (idx = 0; idx < args.nchannels; idx++) {
+        multiFdConnData *c = &args.conn[idx];
+        if (VIR_CLOSE(c->filefd) < 0) {
+            fprintf(stderr, _("%s: failed to close %s\n"), program_name,
+                    args.disk_path);
+            exit(EXIT_FAILURE);
+        }
+    }
+    exit(EXIT_SUCCESS);
+}
-- 
2.26.2

Reply via email to