Dexuan Cui <de...@microsoft.com> writes:

> Hyper-V Sockets (hv_sock) supplies a byte-stream based communication
> mechanism between the host and the guest. It's somewhat like TCP over
> VMBus, but the transportation layer (VMBus) is much simpler than IP.
>
> With Hyper-V Sockets, applications between the host and the guest can talk
> to each other directly by the traditional BSD-style socket APIs.
>
> Hyper-V Sockets is only available on new Windows hosts, like Windows Server
> 2016. More info is in this article "Make your own integration services":
> https://msdn.microsoft.com/en-us/virtualization/hyperv_on_windows/develop/make_mgmt_service
>
> The patch implements the necessary support in the guest side by introducing
> a new socket address family AF_HYPERV.
>
> Signed-off-by: Dexuan Cui <de...@microsoft.com>
> Cc: "K. Y. Srinivasan" <k...@microsoft.com>
> Cc: Haiyang Zhang <haiya...@microsoft.com>
> Cc: Vitaly Kuznetsov <vkuzn...@redhat.com>

Some comments below. The vast majority of them are really minor; the
only thing that bothers me a little is the WARN() in hvsock_sendmsg(),
which I think shouldn't be there. But I may have missed something.

I didn't run any tests on the code.

> Cc: Cathy Avery <cav...@redhat.com>
> ---
>
> You can also get the patch here (2764221d):
> https://github.com/dcui/linux/commits/decui/hv_sock/net-next/20160708_v15
>
> For the change log before v12, please see https://lkml.org/lkml/2016/5/15/31
>
> In v12, the changes are mainly the following:
>
> 1) remove the module params as David suggested.
>
> 2) use exactly 5 pages for the VMBus send/recv rings, respectively.
> The host side's design of the feature requires exactly 5 pages for the recv/send
> rings respectively -- this is suboptimal considering memory consumption,
> but unfortunately we have to live with it until the host comes up with
> a new design in the future. :-(
>
> 3) remove the per-connection static send/recv buffers
> Instead, we allocate and free the buffers dynamically only when we recv/send
> data. This means: when a connection is idle, no memory is consumed as
> recv/send buffers at all.
>
> In v13:
> I return ENOMEM on buffer allocation failure
>
>    Actually "man read/write" says "Other errors may occur, depending on the
> object connected to fd". "man send/recv" indeed lists ENOMEM.
>    Considering AF_HYPERV is a new socket type, ENOMEM seems OK here.
>    In the long run, I think we should add a new API in the VMBus driver,
> allowing data copy from VMBus ringbuffer into user mode buffer directly.
> This way, we can even eliminate this temporary buffer.
>
> In v14:
> fix some coding style issues pointed out by David.
>
> In v15:
> Just some stylistic changes addressing comments from Joe Perches and
> Olaf Hering -- thank you!
> - add a GPL blurb.
> - define a new macro PAGE_SIZE_4K and use it to replace PAGE_SIZE
> - change sk_to_hvsock/hvsock_to_sk() from macros to inline functions
> - remove a not-very-useful pr_err()
> - fix some typos in comment and coding style issues.
>
> Looking forward to your comments!
>
>  MAINTAINERS                 |    2 +
>  include/linux/hyperv.h      |   13 +
>  include/linux/socket.h      |    4 +-
>  include/net/af_hvsock.h     |   78 +++
>  include/uapi/linux/hyperv.h |   24 +
>  net/Kconfig                 |    1 +
>  net/Makefile                |    1 +
>  net/hv_sock/Kconfig         |   10 +
>  net/hv_sock/Makefile        |    3 +
>  net/hv_sock/af_hvsock.c     | 1523 +++++++++++++++++++++++++++++++++++++++++++
>  10 files changed, 1658 insertions(+), 1 deletion(-)
>
> diff --git a/MAINTAINERS b/MAINTAINERS
> index 50f69ba..6eaa26f 100644
> --- a/MAINTAINERS
> +++ b/MAINTAINERS
> @@ -5514,7 +5514,9 @@ F:      drivers/pci/host/pci-hyperv.c
>  F:   drivers/net/hyperv/
>  F:   drivers/scsi/storvsc_drv.c
>  F:   drivers/video/fbdev/hyperv_fb.c
> +F:   net/hv_sock/
>  F:   include/linux/hyperv.h
> +F:   include/net/af_hvsock.h
>  F:   tools/hv/
>  F:   Documentation/ABI/stable/sysfs-bus-vmbus
>
> diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h
> index 50f493e..1cda6ea5 100644
> --- a/include/linux/hyperv.h
> +++ b/include/linux/hyperv.h
> @@ -1508,5 +1508,18 @@ static inline void commit_rd_index(struct vmbus_channel *channel)
>               vmbus_set_event(channel);
>  }
>
> +struct vmpipe_proto_header {
> +     u32 pkt_type;
> +     u32 data_size;
> +};
> +
> +#define HVSOCK_HEADER_LEN    (sizeof(struct vmpacket_descriptor) + \
> +                              sizeof(struct vmpipe_proto_header))
> +
> +/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write() */
> +#define PREV_INDICES_LEN     (sizeof(u64))
>
> +#define HVSOCK_PKT_LEN(payload_len)  (HVSOCK_HEADER_LEN + \
> +                                     ALIGN((payload_len), 8) + \
> +                                     PREV_INDICES_LEN)
>  #endif /* _HYPERV_H */
> diff --git a/include/linux/socket.h b/include/linux/socket.h
> index b5cc5a6..0b68b58 100644
> --- a/include/linux/socket.h
> +++ b/include/linux/socket.h
> @@ -202,8 +202,9 @@ struct ucred {
>  #define AF_VSOCK     40      /* vSockets                     */
>  #define AF_KCM               41      /* Kernel Connection Multiplexor*/
>  #define AF_QIPCRTR   42      /* Qualcomm IPC Router          */
> +#define AF_HYPERV    43      /* Hyper-V Sockets              */
>
> -#define AF_MAX               43      /* For now.. */
> +#define AF_MAX               44      /* For now.. */
>
>  /* Protocol families, same as address families. */
>  #define PF_UNSPEC    AF_UNSPEC
> @@ -251,6 +252,7 @@ struct ucred {
>  #define PF_VSOCK     AF_VSOCK
>  #define PF_KCM               AF_KCM
>  #define PF_QIPCRTR   AF_QIPCRTR
> +#define PF_HYPERV    AF_HYPERV
>  #define PF_MAX               AF_MAX
>
>  /* Maximum queue length specifiable by listen.  */
> diff --git a/include/net/af_hvsock.h b/include/net/af_hvsock.h
> new file mode 100644
> index 0000000..e7a8a3a
> --- /dev/null
> +++ b/include/net/af_hvsock.h
> @@ -0,0 +1,78 @@
> +#ifndef __AF_HVSOCK_H__
> +#define __AF_HVSOCK_H__
> +
> +#include <linux/kernel.h>
> +#include <linux/hyperv.h>
> +#include <net/sock.h>
> +
> +/* The host side's design of the feature requires exactly 5 4KB pages for
> + * recv/send rings respectively -- this is suboptimal considering memory
> + * consumption, but unfortunately we have to live with it until the
> + * host comes up with a better design in the future.
> + */
> +#define PAGE_SIZE_4K         4096
> +#define RINGBUFFER_HVSOCK_RCV_SIZE (PAGE_SIZE_4K * 5)
> +#define RINGBUFFER_HVSOCK_SND_SIZE (PAGE_SIZE_4K * 5)
> +
> +/* The MTU is 16KB per the host side's design.
> + * In the future, the buffer can be eliminated when we switch to using the coming
> + * new VMBus ringbuffer "in-place consumption" APIs, by which we can
> + * directly copy data from VMBus ringbuffer into the userspace buffer.
> + */
> +#define HVSOCK_MTU_SIZE              (1024 * 16)
> +struct hvsock_recv_buf {
> +     unsigned int data_len;
> +     unsigned int data_offset;
> +
> +     struct vmpipe_proto_header hdr;
> +     u8 buf[HVSOCK_MTU_SIZE];
> +};
> +
> +/* In the VM, actually we can send up to HVSOCK_MTU_SIZE bytes of payload,
> + * but for now let's use a smaller size to minimize the dynamically-allocated
> + * buffer. Note: the buffer can be eliminated in the future when we add new VMBus
> + * ringbuffer APIs that allow us to directly copy data from userspace buf to
> + * VMBus ringbuffer.
> + */
> +#define HVSOCK_MAX_SND_SIZE_BY_VM (1024 * 4)
> +struct hvsock_send_buf {
> +     struct vmpipe_proto_header hdr;
> +     u8 buf[HVSOCK_MAX_SND_SIZE_BY_VM];
> +};
> +
> +struct hvsock_sock {
> +     /* sk must be the first member. */
> +     struct sock sk;
> +
> +     struct sockaddr_hv local_addr;
> +     struct sockaddr_hv remote_addr;
> +
> +     /* protected by the global hvsock_mutex */
> +     struct list_head bound_list;
> +     struct list_head connected_list;
> +
> +     struct list_head accept_queue;
> +     /* used by enqueue and dequeue */
> +     struct mutex accept_queue_mutex;
> +
> +     struct delayed_work dwork;
> +
> +     u32 peer_shutdown;
> +
> +     struct vmbus_channel *channel;
> +
> +     struct hvsock_send_buf *send;
> +     struct hvsock_recv_buf *recv;
> +};
> +
> +static inline struct hvsock_sock *sk_to_hvsock(struct sock *sk)
> +{
> +     return (struct hvsock_sock *)sk;
> +}
> +
> +static inline struct sock *hvsock_to_sk(struct hvsock_sock *hvsk)
> +{
> +     return (struct sock *)hvsk;
> +}
> +
> +#endif /* __AF_HVSOCK_H__ */
> diff --git a/include/uapi/linux/hyperv.h b/include/uapi/linux/hyperv.h
> index e347b24..d942996 100644
> --- a/include/uapi/linux/hyperv.h
> +++ b/include/uapi/linux/hyperv.h
> @@ -26,6 +26,7 @@
>  #define _UAPI_HYPERV_H
>
>  #include <linux/uuid.h>
> +#include <linux/socket.h>
>
>  /*
>   * Framework version for util services.
> @@ -396,4 +397,27 @@ struct hv_kvp_ip_msg {
>       struct hv_kvp_ipaddr_value      kvp_ip_val;
>  } __attribute__((packed));
>
> +/* This is the address format of Hyper-V Sockets.
> + * Note: here we just borrow the kernel's built-in type uuid_le. When
> + * an application calls bind() or connect(), the 2 ID members of struct
> + * sockaddr_hv must be in GUID format.
> + * The GUID format differs from the UUID format only in the byte order of
> + * the first 3 fields. Refer to:
> + * https://en.wikipedia.org/wiki/Globally_unique_identifier
> + */
> +#define guid_t uuid_le
> +struct sockaddr_hv {
> +     __kernel_sa_family_t    shv_family;  /* Address family          */
> +     u16             reserved;            /* Must be Zero            */
> +     guid_t          shv_vm_id;           /* VM ID                   */
> +     guid_t          shv_service_id;      /* Service ID              */
> +};

(sorry if it was already discussed before and I missed it)

I'm not sure it is worth it to introduce a new 'guid_t' type here, we
may want to rename

shv_vm_id -> shv_vm_guid
shv_service_id -> shv_service_guid

and use uuid_le type.
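Something like the following (just a sketch; the '_guid' names are my
suggestion, not anything mandated by the host protocol):

    struct sockaddr_hv {
        __kernel_sa_family_t shv_family;       /* Address family */
        u16                  reserved;         /* Must be zero   */
        uuid_le              shv_vm_guid;      /* VM GUID        */
        uuid_le              shv_service_guid; /* Service GUID   */
    };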

> +
> +#define SHV_VMID_GUEST       NULL_UUID_LE
> +#define SHV_VMID_HOST        NULL_UUID_LE
> +
> +#define SHV_SERVICE_ID_ANY   NULL_UUID_LE
> +
> +#define SHV_PROTO_RAW                1
> +
>  #endif /* _UAPI_HYPERV_H */
> diff --git a/net/Kconfig b/net/Kconfig
> index ff40562..0dacc11 100644
> --- a/net/Kconfig
> +++ b/net/Kconfig
> @@ -231,6 +231,7 @@ source "net/dns_resolver/Kconfig"
>  source "net/batman-adv/Kconfig"
>  source "net/openvswitch/Kconfig"
>  source "net/vmw_vsock/Kconfig"
> +source "net/hv_sock/Kconfig"
>  source "net/netlink/Kconfig"
>  source "net/mpls/Kconfig"
>  source "net/hsr/Kconfig"
> diff --git a/net/Makefile b/net/Makefile
> index bdd1455..ec175dd 100644
> --- a/net/Makefile
> +++ b/net/Makefile
> @@ -70,6 +70,7 @@ obj-$(CONFIG_BATMAN_ADV)    += batman-adv/
>  obj-$(CONFIG_NFC)            += nfc/
>  obj-$(CONFIG_OPENVSWITCH)    += openvswitch/
>  obj-$(CONFIG_VSOCKETS)       += vmw_vsock/
> +obj-$(CONFIG_HYPERV_SOCK)    += hv_sock/
>  obj-$(CONFIG_MPLS)           += mpls/
>  obj-$(CONFIG_HSR)            += hsr/
>  ifneq ($(CONFIG_NET_SWITCHDEV),)
> diff --git a/net/hv_sock/Kconfig b/net/hv_sock/Kconfig
> new file mode 100644
> index 0000000..1f41848
> --- /dev/null
> +++ b/net/hv_sock/Kconfig
> @@ -0,0 +1,10 @@
> +config HYPERV_SOCK
> +     tristate "Hyper-V Sockets"
> +     depends on HYPERV
> +     default m if HYPERV
> +     help
> +       Hyper-V Sockets is somewhat like TCP over VMBus, allowing
> +       communication between Linux guest and Hyper-V host without TCP/IP.
> +

I know it's hard to come up with a simple description, but I'd rather
describe it as "Socket interface for high speed communication between
Linux guest and Hyper-V host over VMBus."

> +       To compile this driver as a module, choose M here: the module
> +       will be called hv_sock.
> diff --git a/net/hv_sock/Makefile b/net/hv_sock/Makefile
> new file mode 100644
> index 0000000..716c012
> --- /dev/null
> +++ b/net/hv_sock/Makefile
> @@ -0,0 +1,3 @@
> +obj-$(CONFIG_HYPERV_SOCK) += hv_sock.o
> +
> +hv_sock-y += af_hvsock.o
> diff --git a/net/hv_sock/af_hvsock.c b/net/hv_sock/af_hvsock.c
> new file mode 100644
> index 0000000..f339f38
> --- /dev/null
> +++ b/net/hv_sock/af_hvsock.c
> @@ -0,0 +1,1523 @@
> +/*
> + * Hyper-V Sockets -- a socket-based communication channel between the
> + * Hyper-V host and the virtual machines running on it.
> + *
> + * Copyright (c) 2016 Microsoft Corporation.
> + *
> + * All rights reserved.
> + *
> + * Redistribution and use in source and binary forms, with or without
> + * modification, are permitted provided that the following conditions
> + * are met:
> + *
> + * 1. Redistributions of source code must retain the above copyright
> + *    notice, this list of conditions and the following disclaimer.
> + * 2. Redistributions in binary form must reproduce the above copyright
> + *    notice, this list of conditions and the following disclaimer in the
> + *    documentation and/or other materials provided with the distribution.
> + * 3. The name of the author may not be used to endorse or promote
> + *    products derived from this software without specific prior written
> + *    permission.
> + *
> + * Alternatively, this software may be distributed under the terms of the
> + * GNU General Public License ("GPL") version 2 as published by the Free
> + * Software Foundation.
> + *
> + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
> + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
> + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
> + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
> + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
> + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
> + * POSSIBILITY OF SUCH DAMAGE.
> + */
> +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
> +
> +#include <linux/init.h>
> +#include <linux/module.h>
> +#include <linux/vmalloc.h>
> +#include <net/af_hvsock.h>
> +
> +static struct proto hvsock_proto = {
> +     .name = "HV_SOCK",
> +     .owner = THIS_MODULE,
> +     .obj_size = sizeof(struct hvsock_sock),
> +};
> +
> +#define SS_LISTEN 255
> +
> +static LIST_HEAD(hvsock_bound_list);
> +static LIST_HEAD(hvsock_connected_list);
> +static DEFINE_MUTEX(hvsock_mutex);
> +
> +static bool uuid_equals(uuid_le u1, uuid_le u2)
> +{
> +     return !uuid_le_cmp(u1, u2);
> +}

why not use uuid_le_cmp directly?
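I.e. something like (sketch):

    if (!uuid_le_cmp(addr->shv_service_id,
                     hvsk->local_addr.shv_service_id))
        return hvsock_to_sk(hvsk);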

> +
> +static struct sock *hvsock_find_bound_socket(const struct sockaddr_hv *addr)
> +{
> +     struct hvsock_sock *hvsk;
> +
> +     list_for_each_entry(hvsk, &hvsock_bound_list, bound_list) {
> +             if (uuid_equals(addr->shv_service_id,
> +                             hvsk->local_addr.shv_service_id))
> +                     return hvsock_to_sk(hvsk);
> +     }
> +     return NULL;
> +}
> +
> +static struct sock *hvsock_find_connected_socket_by_channel(
> +     const struct vmbus_channel *channel)
> +{
> +     struct hvsock_sock *hvsk;
> +
> +     list_for_each_entry(hvsk, &hvsock_connected_list, connected_list) {
> +             if (hvsk->channel == channel)
> +                     return hvsock_to_sk(hvsk);
> +     }
> +     return NULL;
> +}
> +
> +static void hvsock_enqueue_accept(struct sock *listener,
> +                               struct sock *connected)
> +{
> +     struct hvsock_sock *hvconnected;
> +     struct hvsock_sock *hvlistener;
> +
> +     hvlistener = sk_to_hvsock(listener);
> +     hvconnected = sk_to_hvsock(connected);
> +
> +     sock_hold(connected);
> +     sock_hold(listener);
> +
> +     mutex_lock(&hvlistener->accept_queue_mutex);
> +     list_add_tail(&hvconnected->accept_queue, &hvlistener->accept_queue);
> +     listener->sk_ack_backlog++;
> +     mutex_unlock(&hvlistener->accept_queue_mutex);
> +}
> +
> +static struct sock *hvsock_dequeue_accept(struct sock *listener)
> +{
> +     struct hvsock_sock *hvconnected;
> +     struct hvsock_sock *hvlistener;
> +
> +     hvlistener = sk_to_hvsock(listener);
> +
> +     mutex_lock(&hvlistener->accept_queue_mutex);
> +
> +     if (list_empty(&hvlistener->accept_queue)) {
> +             mutex_unlock(&hvlistener->accept_queue_mutex);
> +             return NULL;
> +     }
> +
> +     hvconnected = list_entry(hvlistener->accept_queue.next,
> +                              struct hvsock_sock, accept_queue);
> +
> +     list_del_init(&hvconnected->accept_queue);
> +     listener->sk_ack_backlog--;



> +
> +     mutex_unlock(&hvlistener->accept_queue_mutex);
> +
> +     sock_put(listener);
> +     /* The caller will need a reference on the connected socket so we let
> +      * it call sock_put().
> +      */
> +
> +     return hvsock_to_sk(hvconnected);
> +}
> +
> +static bool hvsock_is_accept_queue_empty(struct sock *sk)
> +{
> +     struct hvsock_sock *hvsk = sk_to_hvsock(sk);
> +     int ret;
> +
> +     mutex_lock(&hvsk->accept_queue_mutex);
> +     ret = list_empty(&hvsk->accept_queue);
> +     mutex_unlock(&hvsk->accept_queue_mutex);
> +
> +     return ret;
> +}
> +
> +static void hvsock_addr_init(struct sockaddr_hv *addr, uuid_le service_id)
> +{
> +     memset(addr, 0, sizeof(*addr));
> +     addr->shv_family = AF_HYPERV;
> +     addr->shv_service_id = service_id;
> +}
> +
> +static int hvsock_addr_validate(const struct sockaddr_hv *addr)
> +{
> +     if (!addr)
> +             return -EFAULT;
> +
> +     if (addr->shv_family != AF_HYPERV)
> +             return -EAFNOSUPPORT;
> +
> +     if (addr->reserved != 0)
> +             return -EINVAL;
> +
> +     return 0;
> +}
> +
> +static bool hvsock_addr_bound(const struct sockaddr_hv *addr)
> +{
> +     return !uuid_equals(addr->shv_service_id, SHV_SERVICE_ID_ANY);
> +}
> +
> +static int hvsock_addr_cast(const struct sockaddr *addr, size_t len,
> +                         struct sockaddr_hv **out_addr)
> +{
> +     if (len < sizeof(**out_addr))
> +             return -EFAULT;
> +
> +     *out_addr = (struct sockaddr_hv *)addr;
> +     return hvsock_addr_validate(*out_addr);
> +}
> +
> +static int __hvsock_do_bind(struct hvsock_sock *hvsk,
> +                         struct sockaddr_hv *addr)
> +{
> +     struct sockaddr_hv hv_addr;
> +     int ret = 0;
> +
> +     hvsock_addr_init(&hv_addr, addr->shv_service_id);
> +
> +     mutex_lock(&hvsock_mutex);
> +
> +     if (uuid_equals(addr->shv_service_id, SHV_SERVICE_ID_ANY)) {
> +             do {
> +                     uuid_le_gen(&hv_addr.shv_service_id);
> +             } while (hvsock_find_bound_socket(&hv_addr));
> +     } else {
> +             if (hvsock_find_bound_socket(&hv_addr)) {
> +                     ret = -EADDRINUSE;
> +                     goto out;
> +             }
> +     }
> +
> +     hvsock_addr_init(&hvsk->local_addr, hv_addr.shv_service_id);
> +
> +     sock_hold(&hvsk->sk);
> +     list_add(&hvsk->bound_list, &hvsock_bound_list);
> +out:
> +     mutex_unlock(&hvsock_mutex);
> +
> +     return ret;
> +}
> +
> +static int __hvsock_bind(struct sock *sk, struct sockaddr_hv *addr)
> +{
> +     struct hvsock_sock *hvsk = sk_to_hvsock(sk);
> +     int ret;
> +
> +     if (hvsock_addr_bound(&hvsk->local_addr))
> +             return -EINVAL;
> +
> +     switch (sk->sk_socket->type) {
> +     case SOCK_STREAM:
> +             ret = __hvsock_do_bind(hvsk, addr);
> +             break;
> +
> +     default:
> +             ret = -EINVAL;
> +             break;
> +     }
> +
> +     return ret;
> +}
> +
> +/* Autobind this socket to the local address if necessary. */
> +static int hvsock_auto_bind(struct hvsock_sock *hvsk)
> +{
> +     struct sock *sk = hvsock_to_sk(hvsk);
> +     struct sockaddr_hv local_addr;
> +
> +     if (hvsock_addr_bound(&hvsk->local_addr))
> +             return 0;
> +     hvsock_addr_init(&local_addr, SHV_SERVICE_ID_ANY);
> +     return __hvsock_bind(sk, &local_addr);
> +}
> +
> +static void hvsock_sk_destruct(struct sock *sk)
> +{
> +     struct vmbus_channel *channel;
> +     struct hvsock_sock *hvsk;
> +
> +     hvsk = sk_to_hvsock(sk);
> +     vfree(hvsk->send);
> +     vfree(hvsk->recv);
> +
> +     channel = hvsk->channel;
> +     if (!channel)
> +             return;
> +
> +     vmbus_hvsock_device_unregister(channel);
> +}
> +
> +static void __hvsock_release(struct sock *sk)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *pending;
> +
> +     hvsk = sk_to_hvsock(sk);
> +
> +     mutex_lock(&hvsock_mutex);
> +
> +     if (!list_empty(&hvsk->bound_list)) {
> +             list_del_init(&hvsk->bound_list);
> +             sock_put(&hvsk->sk);
> +     }
> +
> +     if (!list_empty(&hvsk->connected_list)) {
> +             list_del_init(&hvsk->connected_list);
> +             sock_put(&hvsk->sk);
> +     }
> +
> +     mutex_unlock(&hvsock_mutex);
> +
> +     lock_sock(sk);
> +     sock_orphan(sk);
> +     sk->sk_shutdown = SHUTDOWN_MASK;
> +
> +     /* Clean up any sockets that never were accepted. */
> +     while ((pending = hvsock_dequeue_accept(sk)) != NULL) {
> +             __hvsock_release(pending);
> +             sock_put(pending);
> +     }
> +
> +     release_sock(sk);
> +     sock_put(sk);
> +}
> +
> +static int hvsock_release(struct socket *sock)
> +{
> +     /* If accept() is interrupted by a signal, the temporary socket
> +      * struct's sock->sk is NULL.
> +      */
> +     if (sock->sk) {
> +             __hvsock_release(sock->sk);
> +             sock->sk = NULL;
> +     }
> +
> +     sock->state = SS_FREE;
> +     return 0;
> +}
> +
> +static struct sock *hvsock_create(struct net *net, struct socket *sock,
> +                               gfp_t priority, unsigned short type)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +
> +     sk = sk_alloc(net, AF_HYPERV, priority, &hvsock_proto, 0);
> +     if (!sk)
> +             return NULL;
> +
> +     sock_init_data(sock, sk);
> +
> +     /* sk->sk_type is normally set in sock_init_data, but only if sock
> +      * is non-NULL. We make sure that our sockets always have a type by
> +      * setting it here if needed.
> +      */
> +     if (!sock)
> +             sk->sk_type = type;
> +
> +     sk->sk_destruct = hvsock_sk_destruct;
> +
> +     /* Looks stream-based socket doesn't need this. */
> +     sk->sk_backlog_rcv = NULL;
> +
> +     sk->sk_state = 0;
> +     sock_reset_flag(sk, SOCK_DONE);
> +
> +     hvsk = sk_to_hvsock(sk);
> +
> +     hvsk->send = NULL;
> +     hvsk->recv = NULL;
> +
> +     hvsock_addr_init(&hvsk->local_addr, SHV_SERVICE_ID_ANY);
> +     hvsock_addr_init(&hvsk->remote_addr, SHV_SERVICE_ID_ANY);
> +
> +     INIT_LIST_HEAD(&hvsk->bound_list);
> +     INIT_LIST_HEAD(&hvsk->connected_list);
> +
> +     INIT_LIST_HEAD(&hvsk->accept_queue);
> +     mutex_init(&hvsk->accept_queue_mutex);
> +
> +     hvsk->peer_shutdown = 0;
> +
> +     return sk;
> +}
> +
> +static int hvsock_bind(struct socket *sock, struct sockaddr *addr,
> +                    int addr_len)
> +{
> +     struct sockaddr_hv *hv_addr;
> +     struct sock *sk;
> +     int ret;
> +
> +     sk = sock->sk;
> +
> +     if (hvsock_addr_cast(addr, addr_len, &hv_addr) != 0)
> +             return -EINVAL;
> +
> +     if (!uuid_equals(hv_addr->shv_vm_id, NULL_UUID_LE))
> +             return -EINVAL;
> +
> +     lock_sock(sk);
> +     ret = __hvsock_bind(sk, hv_addr);
> +     release_sock(sk);
> +
> +     return ret;
> +}
> +
> +static int hvsock_getname(struct socket *sock,
> +                       struct sockaddr *addr, int *addr_len, int peer)
> +{
> +     struct sockaddr_hv *hv_addr;
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +     int ret;
> +
> +     sk = sock->sk;
> +     hvsk = sk_to_hvsock(sk);
> +     ret = 0;
> +
> +     lock_sock(sk);
> +
> +     if (peer) {
> +             if (sock->state != SS_CONNECTED) {
> +                     ret = -ENOTCONN;
> +                     goto out;
> +             }
> +             hv_addr = &hvsk->remote_addr;
> +     } else {
> +             hv_addr = &hvsk->local_addr;
> +     }
> +
> +     __sockaddr_check_size(sizeof(*hv_addr));
> +
> +     memcpy(addr, hv_addr, sizeof(*hv_addr));
> +     *addr_len = sizeof(*hv_addr);
> +
> +out:
> +     release_sock(sk);
> +     return ret;
> +}
> +
> +static void get_ringbuffer_rw_status(struct vmbus_channel *channel,
> +                                  bool *can_read, bool *can_write)
> +{
> +     u32 avl_read_bytes, avl_write_bytes, dummy;
> +
> +     if (can_read) {
> +             hv_get_ringbuffer_availbytes(&channel->inbound,
> +                                          &avl_read_bytes,
> +                                          &dummy);
> +             /* 0-size payload means FIN */
> +             *can_read = avl_read_bytes >= HVSOCK_PKT_LEN(0);
> +     }
> +
> +     if (can_write) {
> +             hv_get_ringbuffer_availbytes(&channel->outbound,
> +                                          &dummy,
> +                                          &avl_write_bytes);
> +
> +             /* We only write if there is enough space */
> +             *can_write = avl_write_bytes > HVSOCK_PKT_LEN(PAGE_SIZE);
> +     }
> +}
> +
> +static size_t get_ringbuffer_writable_bytes(struct vmbus_channel *channel)
> +{
> +     u32 avl_write_bytes, dummy;
> +     size_t ret;
> +
> +     hv_get_ringbuffer_availbytes(&channel->outbound,
> +                                  &dummy,
> +                                  &avl_write_bytes);
> +
> +     /* The ringbuffer mustn't be 100% full, and we should reserve a
> +      * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
> +      * and hvsock_shutdown().
> +      */
> +     if (avl_write_bytes < HVSOCK_PKT_LEN(1) + HVSOCK_PKT_LEN(0))
> +             return 0;
> +     ret = avl_write_bytes - HVSOCK_PKT_LEN(1) - HVSOCK_PKT_LEN(0);
> +
> +     return round_down(ret, 8);
> +}
> +
> +static int hvsock_get_send_buf(struct hvsock_sock *hvsk)
> +{
> +     hvsk->send = vmalloc(sizeof(*hvsk->send));
> +     return hvsk->send ? 0 : -ENOMEM;
> +}
> +
> +static void hvsock_put_send_buf(struct hvsock_sock *hvsk)
> +{
> +     vfree(hvsk->send);
> +     hvsk->send = NULL;
> +}
> +
> +static int hvsock_send_data(struct vmbus_channel *channel,
> +                         struct hvsock_sock *hvsk,
> +                         size_t to_write)
> +{
> +     hvsk->send->hdr.pkt_type = 1;
> +     hvsk->send->hdr.data_size = to_write;
> +     return vmbus_sendpacket(channel, &hvsk->send->hdr,
> +                             sizeof(hvsk->send->hdr) + to_write,
> +                             0, VM_PKT_DATA_INBAND, 0);
> +}
> +
> +static int hvsock_get_recv_buf(struct hvsock_sock *hvsk)
> +{
> +     hvsk->recv = vmalloc(sizeof(*hvsk->recv));
> +     return hvsk->recv ? 0 : -ENOMEM;
> +}
> +
> +static void hvsock_put_recv_buf(struct hvsock_sock *hvsk)
> +{
> +     vfree(hvsk->recv);
> +     hvsk->recv = NULL;
> +}
> +
> +static int hvsock_recv_data(struct vmbus_channel *channel,
> +                         struct hvsock_sock *hvsk,
> +                         size_t *payload_len)
> +{
> +     u32 buffer_actual_len;
> +     u64 dummy_req_id;
> +     int ret;
> +
> +     ret = vmbus_recvpacket(channel, &hvsk->recv->hdr,
> +                            sizeof(hvsk->recv->hdr) +
> +                            sizeof(hvsk->recv->buf),
> +                            &buffer_actual_len, &dummy_req_id);
> +     if (ret != 0 || buffer_actual_len <= sizeof(hvsk->recv->hdr))
> +             *payload_len = 0;
> +     else
> +             *payload_len = hvsk->recv->hdr.data_size;
> +
> +     return ret;
> +}
> +
> +static int hvsock_shutdown(struct socket *sock, int mode)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +     int ret = 0;
> +
> +     if (mode < SHUT_RD || mode > SHUT_RDWR)
> +             return -EINVAL;
> +     /* This maps:
> +      * SHUT_RD   (0) -> RCV_SHUTDOWN  (1)
> +      * SHUT_WR   (1) -> SEND_SHUTDOWN (2)
> +      * SHUT_RDWR (2) -> SHUTDOWN_MASK (3)
> +      */
> +     ++mode;
> +
> +     if (sock->state != SS_CONNECTED)
> +             return -ENOTCONN;
> +
> +     sock->state = SS_DISCONNECTING;
> +
> +     sk = sock->sk;
> +
> +     lock_sock(sk);
> +
> +     sk->sk_shutdown |= mode;
> +     sk->sk_state_change(sk);
> +
> +     if (mode & SEND_SHUTDOWN) {
> +             hvsk = sk_to_hvsock(sk);
> +
> +             ret = hvsock_get_send_buf(hvsk);
> +             if (ret < 0)
> +                     goto out;
> +
> +             /* It can't fail: see get_ringbuffer_writable_bytes(). */
> +             (void)hvsock_send_data(hvsk->channel, hvsk, 0);
> +
> +             hvsock_put_send_buf(hvsk);
> +     }
> +
> +out:
> +     release_sock(sk);
> +
> +     return ret;
> +}
> +
> +static unsigned int hvsock_poll(struct file *file, struct socket *sock,
> +                             poll_table *wait)
> +{
> +     struct vmbus_channel *channel;
> +     bool can_read, can_write;
> +     struct hvsock_sock *hvsk;
> +     unsigned int mask;
> +     struct sock *sk;
> +
> +     sk = sock->sk;
> +     hvsk = sk_to_hvsock(sk);
> +
> +     poll_wait(file, sk_sleep(sk), wait);
> +     mask = 0;
> +
> +     if (sk->sk_err)
> +             /* Signify that there has been an error on this socket. */
> +             mask |= POLLERR;
> +
> +     /* INET sockets treat local write shutdown and peer write shutdown as a
> +      * case of POLLHUP set.
> +      */
> +     if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
> +         ((sk->sk_shutdown & SEND_SHUTDOWN) &&
> +          (hvsk->peer_shutdown & SEND_SHUTDOWN))) {
> +             mask |= POLLHUP;
> +     }
> +
> +     if (sk->sk_shutdown & RCV_SHUTDOWN ||
> +         hvsk->peer_shutdown & SEND_SHUTDOWN) {
> +             mask |= POLLRDHUP;
> +     }
> +
> +     lock_sock(sk);
> +
> +     /* Listening sockets that have connections in their accept
> +      * queue can be read.
> +      */
> +     if (sk->sk_state == SS_LISTEN && !hvsock_is_accept_queue_empty(sk))
> +             mask |= POLLIN | POLLRDNORM;
> +
> +     /* The mutex is to guard against racing with hvsock_open_connection() */
> +     mutex_lock(&hvsock_mutex);
> +
> +     channel = hvsk->channel;
> +     if (channel) {
> +             /* If there is something in the queue then we can read */
> +             get_ringbuffer_rw_status(channel, &can_read, &can_write);
> +
> +             if (!can_read && hvsk->recv)
> +                     can_read = true;
> +
> +             if (!(sk->sk_shutdown & RCV_SHUTDOWN) && can_read)
> +                     mask |= POLLIN | POLLRDNORM;
> +     } else {
> +             can_read = false;

can_read is not used below this point.
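I.e. the else branch could be reduced to something like (sketch):

    } else {
        /* can_read is never read again, only can_write needs a default */
        can_write = false;
    }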

> +             can_write = false;
> +     }
> +
> +     mutex_unlock(&hvsock_mutex);
> +
> +     /* Sockets whose connections have been closed or terminated should
> +      * also be considered readable, and we check the shutdown flag for that.
> +      */
> +     if (sk->sk_shutdown & RCV_SHUTDOWN ||
> +         hvsk->peer_shutdown & SEND_SHUTDOWN) {
> +             mask |= POLLIN | POLLRDNORM;
> +     }
> +
> +     /* Connected sockets that can produce data can be written. */
> +     if (sk->sk_state == SS_CONNECTED && can_write &&
> +         !(sk->sk_shutdown & SEND_SHUTDOWN)) {
> +             /* Remove POLLWRBAND since INET sockets are not setting it.
> +              */
> +             mask |= POLLOUT | POLLWRNORM;
> +     }
> +
> +     /* Simulate INET socket poll behaviors, which sets
> +      * POLLOUT|POLLWRNORM when peer is closed and nothing to read,
> +      * but local send is not shutdown.
> +      */
> +     if (sk->sk_state == SS_UNCONNECTED &&
> +         !(sk->sk_shutdown & SEND_SHUTDOWN))
> +             mask |= POLLOUT | POLLWRNORM;
> +
> +     release_sock(sk);
> +
> +     return mask;
> +}
> +
> +/* This function runs in the tasklet context of process_chn_event() */
> +static void hvsock_on_channel_cb(void *ctx)
> +{
> +     struct sock *sk = (struct sock *)ctx;
> +     struct vmbus_channel *channel;
> +     struct hvsock_sock *hvsk;
> +     bool can_read, can_write;
> +
> +     hvsk = sk_to_hvsock(sk);
> +     channel = hvsk->channel;
> +     if (!channel) {
> +             WARN_ONCE(1, "NULL channel! There is a programming bug.\n");

BUG() then

> +             return;
> +     }
> +
> +     get_ringbuffer_rw_status(channel, &can_read, &can_write);
> +
> +     if (can_read)
> +             sk->sk_data_ready(sk);
> +
> +     if (can_write)
> +             sk->sk_write_space(sk);
> +}
> +
> +static void hvsock_close_connection(struct vmbus_channel *channel)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +
> +     mutex_lock(&hvsock_mutex);
> +
> +     sk = hvsock_find_connected_socket_by_channel(channel);
> +
> +     /* The guest has already closed the connection? */
> +     if (!sk)
> +             goto out;
> +
> +     sk->sk_state = SS_UNCONNECTED;
> +     sock_set_flag(sk, SOCK_DONE);
> +
> +     hvsk = sk_to_hvsock(sk);
> +     hvsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN;
> +
> +     sk->sk_state_change(sk);
> +out:
> +     mutex_unlock(&hvsock_mutex);
> +}
> +
> +static int hvsock_open_connection(struct vmbus_channel *channel)
> +{
> +     struct hvsock_sock *hvsk, *new_hvsk;
> +     uuid_le *instance, *service_id;
> +     unsigned char conn_from_host;
> +     struct sockaddr_hv hv_addr;
> +     struct sock *sk, *new_sk;
> +     int ret;
> +
> +     instance = &channel->offermsg.offer.if_instance;
> +     service_id = &channel->offermsg.offer.if_type;
> +
> +     /* The first byte != 0 means the host initiated the connection. */
> +     conn_from_host = channel->offermsg.offer.u.pipe.user_def[0];
> +
> +     mutex_lock(&hvsock_mutex);
> +
> +     hvsock_addr_init(&hv_addr, conn_from_host ? *service_id : *instance);
> +     sk = hvsock_find_bound_socket(&hv_addr);
> +
> +     if (!sk || (conn_from_host && sk->sk_state != SS_LISTEN) ||
> +         (!conn_from_host && sk->sk_state != SS_CONNECTING)) {
> +             ret = -ENXIO;
> +             goto out;
> +     }
> +
> +     if (conn_from_host) {
> +             if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) {
> +                     ret = -EMFILE;

I'm not sure -EMFILE is appropriate; we don't really have "too many open
files" here.

> +                     goto out;
> +             }
> +
> +             new_sk = hvsock_create(sock_net(sk), NULL, GFP_KERNEL,
> +                                    sk->sk_type);
> +             if (!new_sk) {
> +                     ret = -ENOMEM;
> +                     goto out;
> +             }
> +
> +             new_sk->sk_state = SS_CONNECTING;
> +             new_hvsk = sk_to_hvsock(new_sk);
> +             new_hvsk->channel = channel;
> +             hvsock_addr_init(&new_hvsk->local_addr, *service_id);
> +             hvsock_addr_init(&new_hvsk->remote_addr, *instance);
> +     } else {
> +             hvsk = sk_to_hvsock(sk);
> +             hvsk->channel = channel;
> +     }
> +
> +     set_channel_read_state(channel, false);
> +     ret = vmbus_open(channel, RINGBUFFER_HVSOCK_SND_SIZE,
> +                      RINGBUFFER_HVSOCK_RCV_SIZE, NULL, 0,
> +                      hvsock_on_channel_cb, conn_from_host ? new_sk : sk);
> +     if (ret != 0) {
> +             if (conn_from_host) {
> +                     new_hvsk->channel = NULL;
> +                     sock_put(new_sk);
> +             } else {
> +                     hvsk->channel = NULL;
> +             }
> +             goto out;
> +     }
> +
> +     vmbus_set_chn_rescind_callback(channel, hvsock_close_connection);
> +
> +     /* see get_ringbuffer_rw_status() */
> +     set_channel_pending_send_size(channel, HVSOCK_PKT_LEN(PAGE_SIZE) + 1);
> +
> +     if (conn_from_host) {
> +             new_sk->sk_state = SS_CONNECTED;
> +
> +             sock_hold(&new_hvsk->sk);
> +             list_add(&new_hvsk->connected_list, &hvsock_connected_list);
> +
> +             hvsock_enqueue_accept(sk, new_sk);
> +     } else {
> +             sk->sk_state = SS_CONNECTED;
> +             sk->sk_socket->state = SS_CONNECTED;
> +
> +             sock_hold(&hvsk->sk);
> +             list_add(&hvsk->connected_list, &hvsock_connected_list);
> +     }
> +
> +     sk->sk_state_change(sk);
> +out:
> +     mutex_unlock(&hvsock_mutex);
> +     return ret;
> +}
> +
> +static void hvsock_connect_timeout(struct work_struct *work)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +
> +     hvsk = container_of(work, struct hvsock_sock, dwork.work);
> +     sk = hvsock_to_sk(hvsk);
> +
> +     lock_sock(sk);
> +     if ((sk->sk_state == SS_CONNECTING) &&
> +         (sk->sk_shutdown != SHUTDOWN_MASK)) {
> +             sk->sk_state = SS_UNCONNECTED;
> +             sk->sk_err = ETIMEDOUT;
> +             sk->sk_error_report(sk);
> +     }
> +     release_sock(sk);
> +
> +     sock_put(sk);
> +}
> +
> +static int hvsock_connect_wait(struct socket *sock,
> +                            int flags, int current_ret)
> +{
> +     struct sock *sk = sock->sk;
> +     struct hvsock_sock *hvsk;
> +     int ret = current_ret;
> +     DEFINE_WAIT(wait);
> +     long timeout;
> +
> +     hvsk = sk_to_hvsock(sk);
> +     timeout = 30 * HZ;

We may want to introduce a define for this timeout. Does it actually
match the host's timeout?
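Something like this, as a sketch (the name is my invention):

    /* Connect timeout; ideally this would match the host side's timeout. */
    #define HVSOCK_CONNECT_TIMEOUT (30 * HZ)
    ...
    timeout = HVSOCK_CONNECT_TIMEOUT;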

> +     prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
> +
> +     while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) {
> +             if (flags & O_NONBLOCK) {
> +                     /* If we're not going to block, we schedule a timeout
> +                      * function to generate a timeout on the connection
> +                      * attempt, in case the peer doesn't respond in a
> +                      * timely manner. We hold on to the socket until the
> +                      * timeout fires.
> +                      */
> +                     sock_hold(sk);
> +                     INIT_DELAYED_WORK(&hvsk->dwork,
> +                                       hvsock_connect_timeout);
> +                     schedule_delayed_work(&hvsk->dwork, timeout);
> +
> +                     /* Skip ahead to preserve error code set above. */
> +                     goto out_wait;
> +             }
> +
> +             release_sock(sk);
> +             timeout = schedule_timeout(timeout);
> +             lock_sock(sk);
> +
> +             if (signal_pending(current)) {
> +                     ret = sock_intr_errno(timeout);
> +                     goto out_wait_error;
> +             } else if (timeout == 0) {
> +                     ret = -ETIMEDOUT;
> +                     goto out_wait_error;
> +             }
> +
> +             prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
> +     }
> +
> +     ret = sk->sk_err ? -sk->sk_err : 0;
> +
> +out_wait_error:
> +     if (ret < 0) {
> +             sk->sk_state = SS_UNCONNECTED;
> +             sock->state = SS_UNCONNECTED;
> +     }
> +out_wait:
> +     finish_wait(sk_sleep(sk), &wait);
> +     return ret;
> +}
> +
> +static int hvsock_connect(struct socket *sock, struct sockaddr *addr,
> +                       int addr_len, int flags)
> +{
> +     struct sockaddr_hv *remote_addr;
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +     int ret = 0;
> +
> +     sk = sock->sk;
> +     hvsk = sk_to_hvsock(sk);
> +
> +     lock_sock(sk);
> +
> +     switch (sock->state) {
> +     case SS_CONNECTED:
> +             ret = -EISCONN;
> +             goto out;
> +     case SS_DISCONNECTING:
> +             ret = -EINVAL;
> +             goto out;
> +     case SS_CONNECTING:
> +             /* This continues on so we can move sock into the SS_CONNECTED
> +              * state once the connection has completed (at which point err
> +              * will be set to zero also).  Otherwise, we will either wait
> +              * for the connection or return -EALREADY should this be a
> +              * non-blocking call.
> +              */
> +             ret = -EALREADY;
> +             break;
> +     default:
> +             if ((sk->sk_state == SS_LISTEN) ||
> +                 hvsock_addr_cast(addr, addr_len, &remote_addr) != 0) {
> +                     ret = -EINVAL;
> +                     goto out;
> +             }
> +
> +             /* Set the remote address that we are connecting to. */
> +             memcpy(&hvsk->remote_addr, remote_addr,
> +                    sizeof(hvsk->remote_addr));
> +
> +             ret = hvsock_auto_bind(hvsk);
> +             if (ret)
> +                     goto out;
> +
> +             sk->sk_state = SS_CONNECTING;
> +
> +             ret = vmbus_send_tl_connect_request(
> +                                     &hvsk->local_addr.shv_service_id,
> +                                     &hvsk->remote_addr.shv_service_id);
> +             if (ret < 0)
> +                     goto out;
> +
> +             /* Mark sock as connecting and set the error code to in
> +              * progress in case this is a non-blocking connect.
> +              */
> +             sock->state = SS_CONNECTING;
> +             ret = -EINPROGRESS;
> +     }
> +
> +     ret = hvsock_connect_wait(sock, flags, ret);
> +out:
> +     release_sock(sk);
> +     return ret;
> +}
> +
> +static int hvsock_accept_wait(struct sock *listener,
> +                           struct socket *newsock, int flags)
> +{
> +     struct hvsock_sock *hvconnected;
> +     struct sock *connected;
> +
> +     DEFINE_WAIT(wait);
> +     long timeout;
> +
> +     int ret = 0;
> +
> +     /* Wait for children sockets to appear; these are the new sockets
> +      * created upon connection establishment.
> +      */
> +     timeout = sock_sndtimeo(listener, flags & O_NONBLOCK);
> +     prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
> +
> +     while ((connected = hvsock_dequeue_accept(listener)) == NULL &&
> +            listener->sk_err == 0) {
> +             release_sock(listener);
> +             timeout = schedule_timeout(timeout);
> +             lock_sock(listener);
> +
> +             if (signal_pending(current)) {
> +                     ret = sock_intr_errno(timeout);
> +                     goto out_wait;
> +             } else if (timeout == 0) {
> +                     ret = -EAGAIN;
> +                     goto out_wait;
> +             }
> +
> +             prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
> +     }
> +
> +     if (listener->sk_err)
> +             ret = -listener->sk_err;
> +
> +     if (connected) {
> +             lock_sock(connected);
> +             hvconnected = sk_to_hvsock(connected);
> +
> +             if (ret) {
> +                     release_sock(connected);
> +                     sock_put(connected);
> +             } else {
> +                     newsock->state = SS_CONNECTED;
> +                     sock_graft(connected, newsock);
> +                     release_sock(connected);
> +                     sock_put(connected);

so we do release_sock()/sock_put() unconditionally and this piece could
be rewritten as

    if (!ret) {
        newsock->state = SS_CONNECTED;
        sock_graft(connected, newsock);
    }
    release_sock(connected);
    sock_put(connected);

> +             }
> +     }
> +
> +out_wait:
> +     finish_wait(sk_sleep(listener), &wait);
> +     return ret;
> +}
> +
> +static int hvsock_accept(struct socket *sock, struct socket *newsock,
> +                      int flags)
> +{
> +     struct sock *listener;
> +     int ret;
> +
> +     listener = sock->sk;
> +
> +     lock_sock(listener);
> +
> +     if (sock->type != SOCK_STREAM) {
> +             ret = -EOPNOTSUPP;
> +             goto out;
> +     }
> +
> +     if (listener->sk_state != SS_LISTEN) {
> +             ret = -EINVAL;
> +             goto out;
> +     }
> +
> +     ret = hvsock_accept_wait(listener, newsock, flags);
> +out:
> +     release_sock(listener);
> +     return ret;
> +}
> +
> +static int hvsock_listen(struct socket *sock, int backlog)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +     int ret = 0;
> +
> +     sk = sock->sk;
> +     lock_sock(sk);
> +
> +     if (sock->type != SOCK_STREAM) {
> +             ret = -EOPNOTSUPP;
> +             goto out;
> +     }
> +
> +     if (sock->state != SS_UNCONNECTED) {
> +             ret = -EINVAL;
> +             goto out;
> +     }
> +
> +     if (backlog <= 0) {
> +             ret = -EINVAL;
> +             goto out;
> +     }
> +     /* This is an artificial limit */
> +     if (backlog > 128)
> +             backlog = 128;

Let's do a define for it.
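E.g. (sketch, the name is mine):

    /* Artificial limit on the listen() backlog. */
    #define HVSOCK_MAX_BACKLOG 128
    ...
    if (backlog > HVSOCK_MAX_BACKLOG)
        backlog = HVSOCK_MAX_BACKLOG;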

> +
> +     hvsk = sk_to_hvsock(sk);
> +     if (!hvsock_addr_bound(&hvsk->local_addr)) {
> +             ret = -EINVAL;
> +             goto out;
> +     }
> +
> +     sk->sk_ack_backlog = 0;
> +     sk->sk_max_ack_backlog = backlog;
> +     sk->sk_state = SS_LISTEN;
> +out:
> +     release_sock(sk);
> +     return ret;
> +}
> +
> +static int hvsock_sendmsg_wait(struct sock *sk, struct msghdr *msg,
> +                            size_t len)
> +{
> +     struct hvsock_sock *hvsk = sk_to_hvsock(sk);
> +     struct vmbus_channel *channel;
> +     size_t total_to_write = len;
> +     size_t total_written = 0;
> +     DEFINE_WAIT(wait);
> +     bool can_write;
> +     long timeout;
> +     int ret = 0;
> +
> +     timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
> +     prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
> +     channel = hvsk->channel;
> +
> +     while (total_to_write > 0) {
> +             size_t to_write, max_writable;
> +
> +             while (1) {
> +                     get_ringbuffer_rw_status(channel, NULL, &can_write);
> +
> +                     if (can_write || sk->sk_err != 0 ||
> +                         (sk->sk_shutdown & SEND_SHUTDOWN) ||
> +                         (hvsk->peer_shutdown & RCV_SHUTDOWN))
> +                             break;
> +
> +                     /* Don't wait for non-blocking sockets. */
> +                     if (timeout == 0) {
> +                             ret = -EAGAIN;
> +                             goto out_wait;
> +                     }
> +
> +                     release_sock(sk);
> +
> +                     timeout = schedule_timeout(timeout);
> +
> +                     lock_sock(sk);
> +                     if (signal_pending(current)) {
> +                             ret = sock_intr_errno(timeout);
> +                             goto out_wait;
> +                     } else if (timeout == 0) {
> +                             ret = -EAGAIN;
> +                             goto out_wait;
> +                     }
> +
> +                     prepare_to_wait(sk_sleep(sk), &wait,
> +                                     TASK_INTERRUPTIBLE);
> +             }
> +
> +             /* These checks occur both as part of and after the loop
> +              * conditional since we need to check before and after
> +              * sleeping.
> +              */
> +             if (sk->sk_err) {
> +                     ret = -sk->sk_err;
> +                     goto out_wait;
> +             } else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
> +                        (hvsk->peer_shutdown & RCV_SHUTDOWN)) {
> +                     ret = -EPIPE;
> +                     goto out_wait;
> +             }
> +
> +             /* Note: that write will only write as many bytes as possible
> +              * in the ringbuffer. It is the caller's responsibility to
> +              * check how many bytes we actually wrote.
> +              */
> +             do {
> +                     max_writable = get_ringbuffer_writable_bytes(channel);
> +                     if (max_writable == 0)
> +                             goto out_wait;
> +
> +                     to_write = min_t(size_t, sizeof(hvsk->send->buf),
> +                                      total_to_write);
> +                     if (to_write > max_writable)
> +                             to_write = max_writable;
> +
> +                     ret = hvsock_get_send_buf(hvsk);
> +                     if (ret < 0)
> +                             goto out_wait;
> +
> +                     ret = memcpy_from_msg(hvsk->send->buf, msg, to_write);
> +                     if (ret != 0) {
> +                             hvsock_put_send_buf(hvsk);
> +                             goto out_wait;
> +                     }
> +
> +                     ret = hvsock_send_data(channel, hvsk, to_write);
> +                     hvsock_put_send_buf(hvsk);
> +                     if (ret != 0)
> +                             goto out_wait;
> +
> +                     total_written += to_write;
> +                     total_to_write -= to_write;
> +             } while (total_to_write > 0);
> +     }
> +
> +out_wait:
> +     if (total_written > 0)
> +             ret = total_written;
> +
> +     finish_wait(sk_sleep(sk), &wait);
> +     return ret;
> +}
> +
> +static int hvsock_sendmsg(struct socket *sock, struct msghdr *msg,
> +                       size_t len)
> +{
> +     struct hvsock_sock *hvsk;
> +     struct sock *sk;
> +     int ret;
> +
> +     if (len == 0)
> +             return -EINVAL;
> +
> +     if (msg->msg_flags & ~MSG_DONTWAIT) {
> +             pr_err("%s: unsupported flags=0x%x\n", __func__,
> +                    msg->msg_flags);

I don't think we need pr_err() here.

> +             return -EOPNOTSUPP;
> +     }
> +
> +     sk = sock->sk;
> +     hvsk = sk_to_hvsock(sk);
> +
> +     lock_sock(sk);
> +
> +     /* Callers should not provide a destination with stream sockets. */
> +     if (msg->msg_namelen) {
> +             ret = -EOPNOTSUPP;
> +             goto out;
> +     }
> +
> +     /* Send data only if both sides are not shutdown in the direction. */
> +     if (sk->sk_shutdown & SEND_SHUTDOWN ||
> +         hvsk->peer_shutdown & RCV_SHUTDOWN) {
> +             ret = -EPIPE;
> +             goto out;
> +     }
> +
> +     if (sk->sk_state != SS_CONNECTED ||
> +         !hvsock_addr_bound(&hvsk->local_addr)) {
> +             ret = -ENOTCONN;
> +             goto out;
> +     }
> +
> +     if (!hvsock_addr_bound(&hvsk->remote_addr)) {
> +             ret = -EDESTADDRREQ;
> +             goto out;
> +     }
> +
> +     ret = hvsock_sendmsg_wait(sk, msg, len);
> +out:
> +     release_sock(sk);
> +
> +     /* ret is a bigger-than-0 total_written or a negative err code. */
> +     if (ret == 0) {
> +             WARN(1, "unexpected return value of 0\n");
> +             ret = -EIO;
> +     }

It seems hvsock_sendmsg_wait() can return 0. I see the following code there:

         max_writable = get_ringbuffer_writable_bytes(channel);
         if (max_writable == 0)
             goto out_wait;

so if there is no space in the ring buffer we won't write
anything; the WARN() is inappropriate then.
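I'd just drop the special-casing and return whatever
hvsock_sendmsg_wait() gave us, e.g. (sketch):

    ret = hvsock_sendmsg_wait(sk, msg, len);
out:
    release_sock(sk);
    /* ret is the number of bytes written (possibly 0) or a negative error */
    return ret;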

> +
> +     return ret;
> +}
> +
> +static int hvsock_recvmsg_wait(struct sock *sk, struct msghdr *msg,
> +                            size_t len, int flags)
> +{
> +     struct hvsock_sock *hvsk = sk_to_hvsock(sk);
> +     size_t to_read, total_to_read = len;
> +     struct vmbus_channel *channel;
> +     DEFINE_WAIT(wait);
> +     size_t copied = 0;
> +     bool can_read;
> +     long timeout;
> +     int ret = 0;
> +
> +     timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
> +     prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
> +     channel = hvsk->channel;
> +
> +     while (1) {
> +             bool need_refill = !hvsk->recv;
> +
> +             if (need_refill) {
> +                     if (hvsk->peer_shutdown & SEND_SHUTDOWN)
> +                             can_read = false;
> +                     else
> +                             get_ringbuffer_rw_status(channel, &can_read,
> +                                                      NULL);
> +             } else {
> +                     can_read = true;
> +             }
> +
> +             if (can_read) {
> +                     size_t payload_len;
> +
> +                     if (need_refill) {
> +                             ret = hvsock_get_recv_buf(hvsk);
> +                             if (ret < 0) {
> +                                     if (copied > 0)
> +                                             ret = copied;
> +                                     goto out_wait;
> +                             }
> +
> +                             ret = hvsock_recv_data(channel, hvsk,
> +                                                    &payload_len);
> +                             if (ret != 0 ||
> +                                 payload_len > sizeof(hvsk->recv->buf)) {
> +                                     ret = -EIO;
> +                                     hvsock_put_recv_buf(hvsk);
> +                                     goto out_wait;
> +                             }
> +
> +                             if (payload_len == 0) {
> +                                     ret = copied;
> +                                     hvsock_put_recv_buf(hvsk);
> +                                     hvsk->peer_shutdown |= SEND_SHUTDOWN;
> +                                     break;
> +                             }
> +
> +                             hvsk->recv->data_len = payload_len;
> +                             hvsk->recv->data_offset = 0;
> +                     }
> +
> +                     to_read = min_t(size_t, total_to_read,
> +                                     hvsk->recv->data_len);
> +
> +                     ret = memcpy_to_msg(msg, hvsk->recv->buf +
> +                                         hvsk->recv->data_offset,
> +                                         to_read);
> +                     if (ret != 0)
> +                             break;
> +
> +                     copied += to_read;
> +                     total_to_read -= to_read;
> +
> +                     hvsk->recv->data_len -= to_read;
> +
> +                     if (hvsk->recv->data_len == 0)
> +                             hvsock_put_recv_buf(hvsk);
> +                     else
> +                             hvsk->recv->data_offset += to_read;
> +
> +                     if (total_to_read == 0)
> +                             break;
> +             } else {
> +                     if (sk->sk_err || (sk->sk_shutdown & RCV_SHUTDOWN) ||
> +                         (hvsk->peer_shutdown & SEND_SHUTDOWN))
> +                             break;
> +
> +                     /* Don't wait for non-blocking sockets. */
> +                     if (timeout == 0) {
> +                             ret = -EAGAIN;
> +                             break;
> +                     }
> +
> +                     if (copied > 0)
> +                             break;
> +
> +                     release_sock(sk);
> +                     timeout = schedule_timeout(timeout);
> +                     lock_sock(sk);
> +
> +                     if (signal_pending(current)) {
> +                             ret = sock_intr_errno(timeout);
> +                             break;
> +                     } else if (timeout == 0) {
> +                             ret = -EAGAIN;
> +                             break;
> +                     }
> +
> +                     prepare_to_wait(sk_sleep(sk), &wait,
> +                                     TASK_INTERRUPTIBLE);
> +             }
> +     }
> +
> +     if (sk->sk_err)
> +             ret = -sk->sk_err;
> +     else if (sk->sk_shutdown & RCV_SHUTDOWN)
> +             ret = 0;
> +
> +     if (copied > 0)
> +             ret = copied;
> +out_wait:
> +     finish_wait(sk_sleep(sk), &wait);
> +     return ret;
> +}
> +
> +static int hvsock_recvmsg(struct socket *sock, struct msghdr *msg,
> +                       size_t len, int flags)
> +{
> +     struct sock *sk = sock->sk;
> +     int ret;
> +
> +     lock_sock(sk);
> +
> +     if (sk->sk_state != SS_CONNECTED) {
> +             /* Recvmsg is supposed to return 0 if a peer performs an
> +              * orderly shutdown. Differentiate between that case and when a
> +              * peer has not connected or a local shutdown occurred with the
> +              * SOCK_DONE flag.
> +              */
> +             if (sock_flag(sk, SOCK_DONE))
> +                     ret = 0;
> +             else
> +                     ret = -ENOTCONN;
> +
> +             goto out;
> +     }
> +
> +     /* We ignore msg->addr_name/len. */
> +     if (flags & ~MSG_DONTWAIT) {
> +             pr_err("%s: unsupported flags=0x%x\n", __func__, flags);

Here we may also want to drop the pr_err().

> +             ret = -EOPNOTSUPP;
> +             goto out;
> +     }
> +
> +     /* We don't check peer_shutdown flag here since peer may actually shut
> +      * down, but there can be data in the queue that a local socket can
> +      * receive.
> +      */
> +     if (sk->sk_shutdown & RCV_SHUTDOWN) {
> +             ret = 0;
> +             goto out;
> +     }
> +
> +     /* It is valid on Linux to pass in a zero-length receive buffer.  This
> +      * is not an error.  We may as well bail out now.
> +      */
> +     if (!len) {
> +             ret = 0;
> +             goto out;
> +     }
> +
> +     ret = hvsock_recvmsg_wait(sk, msg, len, flags);
> +out:
> +     release_sock(sk);
> +     return ret;
> +}
> +
> +static const struct proto_ops hvsock_ops = {
> +     .family = PF_HYPERV,
> +     .owner = THIS_MODULE,
> +     .release = hvsock_release,
> +     .bind = hvsock_bind,
> +     .connect = hvsock_connect,
> +     .socketpair = sock_no_socketpair,
> +     .accept = hvsock_accept,
> +     .getname = hvsock_getname,
> +     .poll = hvsock_poll,
> +     .ioctl = sock_no_ioctl,
> +     .listen = hvsock_listen,
> +     .shutdown = hvsock_shutdown,
> +     .setsockopt = sock_no_setsockopt,
> +     .getsockopt = sock_no_getsockopt,
> +     .sendmsg = hvsock_sendmsg,
> +     .recvmsg = hvsock_recvmsg,
> +     .mmap = sock_no_mmap,
> +     .sendpage = sock_no_sendpage,
> +};
> +
> +static int hvsock_create_sock(struct net *net, struct socket *sock,
> +                           int protocol, int kern)
> +{
> +     struct sock *sk;
> +
> +     if (!capable(CAP_SYS_ADMIN) && !capable(CAP_NET_ADMIN))
> +             return -EPERM;

I'd say we're OK with CAP_SYS_ADMIN only for now, and we'll be able to
drop the check when the host starts supporting a single pair of ring buffers
for all Hyper-V sockets on the system.
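I.e. just (sketch):

    if (!capable(CAP_SYS_ADMIN))
        return -EPERM;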

> +
> +     if (protocol != 0 && protocol != SHV_PROTO_RAW)
> +             return -EPROTONOSUPPORT;
> +
> +     switch (sock->type) {
> +     case SOCK_STREAM:
> +             sock->ops = &hvsock_ops;
> +             break;
> +     default:
> +             return -ESOCKTNOSUPPORT;
> +     }
> +
> +     sock->state = SS_UNCONNECTED;
> +
> +     sk = hvsock_create(net, sock, GFP_KERNEL, 0);
> +     return sk ? 0 : -ENOMEM;
> +}
> +
> +static const struct net_proto_family hvsock_family_ops = {
> +     .family = AF_HYPERV,
> +     .create = hvsock_create_sock,
> +     .owner = THIS_MODULE,
> +};
> +
> +static int hvsock_probe(struct hv_device *hdev,
> +                     const struct hv_vmbus_device_id *dev_id)
> +{
> +     struct vmbus_channel *channel = hdev->channel;
> +
> +     /* We ignore the error return code to suppress the unnecessary
> +      * error message in vmbus_probe(): on error the host will rescind
> +      * the offer in 30 seconds and we can do cleanup at that time.
> +      */
> +     (void)hvsock_open_connection(channel);
> +
> +     return 0;
> +}
> +
> +static int hvsock_remove(struct hv_device *hdev)
> +{
> +     struct vmbus_channel *channel = hdev->channel;
> +
> +     vmbus_close(channel);
> +
> +     return 0;
> +}
> +
> +/* It's not really used. See vmbus_match() and vmbus_probe(). */
> +static const struct hv_vmbus_device_id id_table[] = {
> +     {},
> +};
> +
> +static struct hv_driver hvsock_drv = {
> +     .name           = "hv_sock",
> +     .hvsock         = true,
> +     .id_table       = id_table,
> +     .probe          = hvsock_probe,
> +     .remove         = hvsock_remove,
> +};
> +
> +static int __init hvsock_init(void)
> +{
> +     int ret;
> +
> +     /* Hyper-V Sockets requires at least VMBus 4.0 */
> +     if ((vmbus_proto_version >> 16) < 4)
> +             return -ENODEV;

So it's actually

  if (vmbus_proto_version < VERSION_WIN10)

I suggest we use such a comparison to be in line with the other places where
vmbus_proto_version is checked.

> +
> +     ret = vmbus_driver_register(&hvsock_drv);
> +     if (ret) {
> +             pr_err("failed to register hv_sock driver\n");
> +             return ret;
> +     }
> +
> +     ret = proto_register(&hvsock_proto, 0);
> +     if (ret) {
> +             pr_err("failed to register protocol\n");
> +             goto unreg_hvsock_drv;
> +     }
> +
> +     ret = sock_register(&hvsock_family_ops);
> +     if (ret) {
> +             pr_err("failed to register address family\n");
> +             goto unreg_proto;
> +     }
> +
> +     return 0;
> +
> +unreg_proto:
> +     proto_unregister(&hvsock_proto);
> +unreg_hvsock_drv:
> +     vmbus_driver_unregister(&hvsock_drv);
> +     return ret;
> +}
> +
> +static void __exit hvsock_exit(void)
> +{
> +     sock_unregister(AF_HYPERV);
> +     proto_unregister(&hvsock_proto);
> +     vmbus_driver_unregister(&hvsock_drv);
> +}
> +
> +module_init(hvsock_init);
> +module_exit(hvsock_exit);
> +
> +MODULE_DESCRIPTION("Hyper-V Sockets");
> +MODULE_LICENSE("Dual BSD/GPL");

-- 
  Vitaly