On 4/26/22 6:47 PM, Claudio Fontana wrote:
> For the save direction, this helper listens on a unix socket
> which QEMU connects to for multifd migration to files.
> 
> For the restore direction, this helper connects to a unix socket on which
> QEMU listens, for multifd migration from files.
> 
> The file descriptors are passed as command line parameters.
> 
> Signed-off-by: Claudio Fontana <cfont...@suse.de>
> ---
>  src/libvirt_private.syms  |   1 +
>  src/util/meson.build      |  13 ++
>  src/util/multifd-helper.c | 250 ++++++++++++++++++++++++++++++++++++++
>  src/util/virthread.c      |   5 +
>  src/util/virthread.h      |   1 +
>  5 files changed, 270 insertions(+)
>  create mode 100644 src/util/multifd-helper.c
> 
> diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms
> index 97bfca906b..5f2bee985e 100644
> --- a/src/libvirt_private.syms
> +++ b/src/libvirt_private.syms
> @@ -3427,6 +3427,7 @@ virThreadCreateFull;
>  virThreadID;
>  virThreadIsSelf;
>  virThreadJoin;
> +virThreadJoinRet;
>  virThreadMaxName;
>  virThreadSelf;
>  virThreadSelfID;
> diff --git a/src/util/meson.build b/src/util/meson.build
> index 58001a1699..8ea74ff9e8 100644
> --- a/src/util/meson.build
> +++ b/src/util/meson.build
> @@ -179,6 +179,12 @@ io_helper_sources = [
>    'runio.h',
>  ]
>  
> +multifd_helper_sources = [
> +  'multifd-helper.c',
> +  'runio.c',
> +  'runio.h',
> +]
> +
>  virt_util_lib = static_library(
>    'virt_util',
>    [
> @@ -216,6 +222,13 @@ if conf.has('WITH_LIBVIRTD')
>        dtrace_gen_headers,
>      ],
>    }
> +  virt_helpers += {
> +    'name': 'libvirt_multifd_helper',
> +    'sources': [
> +      files(multifd_helper_sources),
> +      dtrace_gen_headers,
> +    ],
> +  }
>  endif
>  
>  util_inc_dir = include_directories('.')
> diff --git a/src/util/multifd-helper.c b/src/util/multifd-helper.c
> new file mode 100644
> index 0000000000..37e61a3a4d
> --- /dev/null
> +++ b/src/util/multifd-helper.c
> @@ -0,0 +1,250 @@
> +/*
> + * multifd-helper.c: listens on Unix socket to perform I/O to multiple files
> + *
> + * Copyright (C) 2022 SUSE LLC
> + *
> + * This library is free software; you can redistribute it and/or
> + * modify it under the terms of the GNU Lesser General Public
> + * License as published by the Free Software Foundation; either
> + * version 2.1 of the License, or (at your option) any later version.
> + *
> + * This library is distributed in the hope that it will be useful,
> + * but WITHOUT ANY WARRANTY; without even the implied warranty of
> + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
> + * Lesser General Public License for more details.
> + *
> + * You should have received a copy of the GNU Lesser General Public
> + * License along with this library.  If not, see
> + * <http://www.gnu.org/licenses/>.
> + *
> + * This has been written to support QEMU multifd migration to file,
> + * allowing better use of cpu resources to speed up the save/restore.
> + */
> +
> +#include <config.h>
> +
> +#include <unistd.h>
> +#include <fcntl.h>
> +#include <stdlib.h>
> +#include <sys/types.h>
> +#include <sys/stat.h>
> +#include <sys/socket.h>
> +#include <sys/un.h>
> +
> +#include "virthread.h"
> +#include "virfile.h"
> +#include "virerror.h"
> +#include "virstring.h"
> +#include "virgettext.h"
> +#include "runio.h"
> +
> +#define VIR_FROM_THIS VIR_FROM_STORAGE
> +
> +/* Per-connection state: one file FD from the command line, the unix
> + * socket connection paired with it, and the thread copying between them. */
> +typedef struct _multiFdConnData multiFdConnData;
> +struct _multiFdConnData {
> +    int clientfd;              /* socket connection; closed by waitClientThreads() */
> +    int filefd;                /* file FD parsed from the command line */
> +    int oflags;                /* fcntl(filefd, F_GETFL) result */
> +    const char *path;          /* passed to runIO(); NOTE(review): not assigned
> +                                * anywhere in this file, so it stays NULL from
> +                                * g_new0() -- confirm runIO() tolerates that */
> +    virThread tid;             /* thread running clientThreadFunc() */
> +
> +    off_t total;               /* result stored by clientThreadFunc() from runIO() */
> +};
> +
> +/* Arguments shared between main() and the load/save server thread. */
> +typedef struct _multiFdThreadArgs multiFdThreadArgs;
> +struct _multiFdThreadArgs {
> +    int nchannels;             /* N from the command line (channels besides main) */
> +    multiFdConnData *conn;     /* contains main fd + nchannels */
> +    const char *sun_path;      /* unix socket name to use for the server */
> +    struct sockaddr_un serv_addr; /* AF_UNIX address built from sun_path */
> +
> +    off_t total;               /* overall result; main() treats < 0 as failure */
> +};
> +
> +/*
> + * Thread body for one connection: run runIO() between the connection's
> + * file FD and its socket, and keep the result in conn->total so that
> + * waitClientThreads() can collect it after joining this thread.
> + */
> +static void clientThreadFunc(void *a)
> +{
> +    multiFdConnData *conn = a;
> +    int sockfd = conn->clientfd;
> +
> +    /* the socket is passed for both of runIO()'s remote FD arguments */
> +    conn->total = runIO(conn->path, conn->filefd, conn->oflags,
> +                        sockfd, sockfd);
> +}
> +
> +/*
> + * Join the first @n client threads in @conn, close their sockets and
> + * return the sum of their per-connection totals.  Returns -1 if any join
> + * failed, any thread reported a negative total, or any close failed.
> + * All @n entries are joined and closed even after a failure.
> + */
> +static off_t waitClientThreads(multiFdConnData *conn, int n)
> +{
> +    int i;
> +    off_t total = 0;
> +
> +    for (i = 0; i < n; i++) {
> +        multiFdConnData *c = &conn[i];
> +
> +        /* "!= 0" rather than "< 0": pthread_join() style functions return
> +         * a positive error number.  c->total is only read after a
> +         * successful join, and a negative value means runIO() failed. */
> +        if (virThreadJoinRet(&c->tid) != 0 || c->total < 0) {
> +            total = -1;
> +        } else if (total >= 0) {
> +            total += c->total;
> +        }
> +        if (VIR_CLOSE(c->clientfd) < 0) {
> +            total = -1;
> +        }
> +    }
> +    return total;
> +}
> +
> +/*
> + * Restore direction: open nchannels + 1 outgoing connections to the unix
> + * socket at args->serv_addr, start one clientThreadFunc() per connection
> + * and wait for all of them.  args->total receives the summed result from
> + * waitClientThreads(), or -1 if anything failed.
> + */
> +static void loadThreadFunc(void *a)
> +{
> +    multiFdThreadArgs *args = a;
> +    int i;
> +    args->total = -1;
> +
> +    /* Mark every slot unconnected first, so the cleanup loop only closes
> +     * sockets created here.  A zero-filled slot (conn[] comes from
> +     * g_new0()) would otherwise make it close fd 0. */
> +    for (i = 0; i < args->nchannels + 1; i++)
> +        args->conn[i].clientfd = -1;
> +
> +    for (i = 0; i < args->nchannels + 1; i++) {
> +        /* Perform outgoing connections */
> +        multiFdConnData *c = &args->conn[i];
> +        c->clientfd = socket(AF_UNIX, SOCK_STREAM, 0);
> +        if (c->clientfd < 0) {
> +            virReportSystemError(errno, "%s", _("loadThread: socket() failed"));
> +            goto cleanup;
> +        }
> +        if (connect(c->clientfd, (const struct sockaddr *)&args->serv_addr,
> +                    sizeof(struct sockaddr_un)) < 0) {
> +            virReportSystemError(errno, "%s", _("loadThread: connect() failed"));
> +            goto cleanup;
> +        }
> +        if (virThreadCreate(&c->tid, true, &clientThreadFunc, c) < 0) {
> +            virReportSystemError(errno, "%s",
> +                                 _("loadThread: client thread creation failed"));
> +            goto cleanup;
> +        }
> +    }
> +    args->total = waitClientThreads(args->conn, args->nchannels + 1);
> +
> + cleanup:
> +    /* After waitClientThreads() every clientfd is already -1.  On an early
> +     * error, NOTE(review): threads that were already started are not
> +     * joined and their sockets are closed under them; total stays -1, so
> +     * main() exits with EXIT_FAILURE. */
> +    for (i = 0; i < args->nchannels + 1; i++) {
> +        multiFdConnData *c = &args->conn[i];
> +        VIR_FORCE_CLOSE(c->clientfd);
> +    }
> +}
> +
> +/*
> + * Save direction: listen on args->sun_path, tell the parent that the
> + * server is ready by writing one byte to stdout, accept nchannels + 1
> + * connections (stored in args->conn[] in accept order) and run one
> + * clientThreadFunc() per connection.  args->total receives the summed
> + * result from waitClientThreads(), or -1 if anything failed.
> + */
> +static void saveThreadFunc(void *a)
> +{
> +    multiFdThreadArgs *args = a;
> +    int i;
> +    const char buf[1] = {'R'};
> +    int sockfd;
> +
> +    /* Assume failure until all channels have completed.  Without this,
> +     * an early error would leave a zero total behind, which the caller
> +     * cannot tell apart from success. */
> +    args->total = -1;
> +
> +    /* Mark every slot unconnected so the cleanup loop only closes sockets
> +     * accepted here.  A zero-filled slot would otherwise make it close fd 0. */
> +    for (i = 0; i < args->nchannels + 1; i++)
> +        args->conn[i].clientfd = -1;
> +
> +    if ((sockfd = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
> +        virReportSystemError(errno, "%s", _("saveThread: socket() failed"));
> +        return;
> +    }
> +    /* bind() fails if a file already exists at the path, so remove it first */
> +    unlink(args->sun_path);
> +    if (bind(sockfd, (struct sockaddr *)&args->serv_addr,
> +             sizeof(args->serv_addr)) < 0) {
> +        virReportSystemError(errno, "%s", _("saveThread: bind() failed"));
> +        goto cleanup;
> +    }
> +    if (listen(sockfd, args->nchannels + 1) < 0) {
> +        virReportSystemError(errno, "%s", _("saveThread: listen() failed"));
> +        goto cleanup;
> +    }
> +
> +    /* signal that the server is ready */
> +    if (safewrite(STDOUT_FILENO, buf, 1) != 1) {
> +        virReportSystemError(errno, "%s", _("saveThread: safewrite failed"));
> +        goto cleanup;
> +    }
> +
> +    for (i = 0; i < args->nchannels + 1; i++) {
> +        /* Wait for incoming connection. */
> +        multiFdConnData *c = &args->conn[i];
> +        if ((c->clientfd = accept(sockfd, NULL, NULL)) < 0) {
> +            virReportSystemError(errno, "%s", _("saveThread: accept() failed"));
> +            goto cleanup;
> +        }
> +        if (virThreadCreate(&c->tid, true, &clientThreadFunc, c) < 0) {
> +            virReportSystemError(errno, "%s",
> +                                 _("saveThread: client thread creation failed"));
> +            goto cleanup;
> +        }
> +    }
> +
> +    args->total = waitClientThreads(args->conn, args->nchannels + 1);
> +
> + cleanup:
> +    /* After waitClientThreads() every clientfd is already -1.  On an early
> +     * error, NOTE(review): threads that were already started are not
> +     * joined and their sockets are closed under them; total stays -1, so
> +     * main() exits with EXIT_FAILURE. */
> +    for (i = 0; i < args->nchannels + 1; i++) {
> +        multiFdConnData *c = &args->conn[i];
> +        VIR_FORCE_CLOSE(c->clientfd);
> +    }
> +    if (VIR_CLOSE(sockfd) < 0)
> +        args->total = -1;
> +}
> +
> +/* argv[0], set by main() and used in diagnostics */
> +static const char *program_name;
> +
> +/*
> + * Print a usage hint to stderr and exit with @status.  A non-zero @status
> + * prints a short "try --help" hint; zero prints the full usage line.
> + * MAINFD is followed by one FD per channel, so there are N + 1 FDs in total.
> + */
> +G_GNUC_NORETURN static void
> +usage(int status)
> +{
> +    if (status) {
> +        fprintf(stderr, _("%s: try --help for more details\n"), program_name);
> +    } else {
> +        fprintf(stderr, _("Usage: %s UNIX_SOCNAME N MAINFD FD0 ... FD(N-1)\n"),
> +                program_name);
> +    }
> +    exit(status);
> +}
> +
> +#ifndef F_GETFL
> +#error "multifd-helper requires F_GETFL parameter of fcntl"
> +#endif
> +
> +/*
> + * Usage: UNIX_SOCNAME N MAINFD FD0 ... FD(N-1)
> + *
> + * Parse the arguments and pick the direction from the FDs' access mode:
> + * if any FD is read-only, run loadThreadFunc (restore); otherwise run
> + * saveThreadFunc (save).  Run that server thread, then exit with
> + * EXIT_FAILURE if the thread could not be joined or reported a
> + * negative total.
> + */
> +int
> +main(int argc, char **argv)
> +{
> +    virThread tid;
> +    virThreadFunc func = saveThreadFunc;
> +    multiFdThreadArgs args = { 0 };
> +    int i;
> +


The sleep(10) that was here should of course not be there; it was only
handy during debugging.

> +    program_name = argv[0];
> +
> +    if (virGettextInitialize() < 0 ||
> +        virErrorInitialize() < 0) {
> +        fprintf(stderr, _("%s: initialization failed\n"), program_name);
> +        exit(EXIT_FAILURE);
> +    }
> +
> +    if (argc > 1 && STREQ(argv[1], "--help"))
> +        usage(EXIT_SUCCESS);
> +    if (argc < 4)
> +        usage(EXIT_FAILURE);
> +
> +    args.sun_path = argv[1];
> +    /* a bad N must be fatal; carrying on would size conn[] from garbage */
> +    if (virStrToLong_i(argv[2], NULL, 10, &args.nchannels) < 0 ||
> +        args.nchannels < 0) {
> +        fprintf(stderr, _("%s: malformed number of channels N %s\n"),
> +                program_name, argv[2]);
> +        usage(EXIT_FAILURE);
> +    }
> +
> +    /* MAINFD and N channel FDs must follow.  Written as "N > argc - 4"
> +     * rather than "argc < 4 + N" so a huge N cannot overflow. */
> +    if (args.nchannels > argc - 4)
> +        usage(EXIT_FAILURE);
> +
> +    /* sun_path must also hold the terminating NUL.  Truncating the name
> +     * would make us unlink/bind/connect a different path. */
> +    if (strlen(args.sun_path) >= sizeof(args.serv_addr.sun_path)) {
> +        fprintf(stderr, _("%s: socket path too long: %s\n"),
> +                program_name, args.sun_path);
> +        exit(EXIT_FAILURE);
> +    }
> +
> +    args.conn = g_new0(multiFdConnData, args.nchannels + 1);
> +
> +    for (i = 0; i < args.nchannels + 1; i++) {
> +        multiFdConnData *c = &args.conn[i];
> +        const char *fdarg = argv[3 + i];
> +
> +        /* not connected yet; g_new0() would otherwise leave 0 (stdin) here */
> +        c->clientfd = -1;
> +
> +        if (virStrToLong_i(fdarg, NULL, 10, &c->filefd) < 0) {
> +            fprintf(stderr, _("%s: malformed FD %s\n"), program_name, fdarg);
> +            usage(EXIT_FAILURE);
> +        }
> +        /* fcntl() returns -1 for a bad FD; do not treat that as "writable" */
> +        if ((c->oflags = fcntl(c->filefd, F_GETFL)) < 0) {
> +            virReportSystemError(errno, _("%s: cannot get flags of FD %s"),
> +                                 program_name, fdarg);
> +            exit(EXIT_FAILURE);
> +        }
> +        if ((c->oflags & O_ACCMODE) == O_RDONLY) {
> +            func = loadThreadFunc;
> +        }
> +    }
> +
> +    /* initialize server address structure (length validated above) */
> +    memset(&args.serv_addr, 0, sizeof(args.serv_addr));
> +    args.serv_addr.sun_family = AF_UNIX;
> +    memcpy(args.serv_addr.sun_path, args.sun_path, strlen(args.sun_path) + 1);
> +
> +    if (virThreadCreate(&tid, true, func, &args) < 0) {
> +        virReportSystemError(errno, _("%s: failed to create server thread"),
> +                             program_name);
> +        exit(EXIT_FAILURE);
> +    }
> +
> +    /* any non-zero result is a failure, whichever error convention
> +     * (-1/errno or a returned error number) virThreadJoinRet() uses */
> +    if (virThreadJoinRet(&tid) != 0)
> +        exit(EXIT_FAILURE);
> +
> +    if (args.total < 0)
> +        exit(EXIT_FAILURE);
> +
> +    exit(EXIT_SUCCESS);
> +}
> diff --git a/src/util/virthread.c b/src/util/virthread.c
> index 5422bb74fd..0f6c6a68fa 100644
> --- a/src/util/virthread.c
> +++ b/src/util/virthread.c
> @@ -348,6 +348,11 @@ void virThreadJoin(virThread *thread)
>      pthread_join(thread->thread, NULL);
>  }
>  
> +/*
> + * Join @thread like virThreadJoin(), but report the outcome: returns 0 on
> + * success, or -1 with errno set on failure.
> + */
> +int virThreadJoinRet(virThread *thread)
> +{
> +    /* pthread_join() returns an error number and does not set errno.
> +     * Convert it so callers can test "< 0" and report errno, as they do
> +     * for virThreadCreate(). */
> +    int err = pthread_join(thread->thread, NULL);
> +
> +    if (err != 0) {
> +        errno = err;
> +        return -1;
> +    }
> +    return 0;
> +}
> +
>  void virThreadCancel(virThread *thread)
>  {
>      pthread_cancel(thread->thread);
> diff --git a/src/util/virthread.h b/src/util/virthread.h
> index 23abe0b6c9..5cecb9bd8a 100644
> --- a/src/util/virthread.h
> +++ b/src/util/virthread.h
> @@ -89,6 +89,7 @@ int virThreadCreateFull(virThread *thread,
>  void virThreadSelf(virThread *thread);
>  bool virThreadIsSelf(virThread *thread);
>  void virThreadJoin(virThread *thread);
> +int virThreadJoinRet(virThread *thread);
>  
>  size_t virThreadMaxName(void);
>  
> 

Reply via email to