From: Arseniy Krasnov <oxfff...@gmail.com>

---
 net/vmw_vsock/af_vsock.c | 107 +++++++++++++++++++++++++++++++++------
 1 file changed, 91 insertions(+), 16 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 7ff00449a9a2..30caad9815f7 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -452,6 +452,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                new_transport = transport_dgram;
                break;
        case SOCK_STREAM:
+       case SOCK_SEQPACKET:
                if (vsock_use_local_transport(remote_cid))
                        new_transport = transport_local;
                else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
@@ -459,6 +460,12 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                        new_transport = transport_g2h;
                else
                        new_transport = transport_h2g;
+
+               /* new_transport can be NULL here, so test it before its ops. */
+               if (sk->sk_type == SOCK_SEQPACKET &&
+                   (!new_transport || !new_transport->seqpacket_seq_send_len ||
+                    !new_transport->seqpacket_seq_get_len))
+                       return -ENODEV;
                break;
        default:
                return -ESOCKTNOSUPPORT;
@@ -604,8 +611,8 @@ static void vsock_pending_work(struct work_struct *work)
 
 /**** SOCKET OPERATIONS ****/
 
-static int __vsock_bind_stream(struct vsock_sock *vsk,
-                              struct sockaddr_vm *addr)
+static int __vsock_bind_connectible(struct vsock_sock *vsk,
+                                   struct sockaddr_vm *addr)
 {
        static u32 port;
        struct sockaddr_vm new_addr;
@@ -684,8 +691,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 
        switch (sk->sk_socket->type) {
        case SOCK_STREAM:
+       case SOCK_SEQPACKET:
                spin_lock_bh(&vsock_table_lock);
-               retval = __vsock_bind_stream(vsk, addr);
+               retval = __vsock_bind_connectible(vsk, addr);
                spin_unlock_bh(&vsock_table_lock);
                break;
 
@@ -767,6 +775,12 @@ static struct sock *__vsock_create(struct net *net,
        return sk;
 }
 
+/* Return true for the connection-based types: SOCK_STREAM, SOCK_SEQPACKET. */
+static bool sock_type_connectible(u16 type)
+{
+       return type == SOCK_STREAM || type == SOCK_SEQPACKET;
+}
+
 static void __vsock_release(struct sock *sk, int level)
 {
        if (sk) {
@@ -785,7 +798,7 @@ static void __vsock_release(struct sock *sk, int level)
 
                if (vsk->transport)
                        vsk->transport->release(vsk);
-               else if (sk->sk_type == SOCK_STREAM)
+               else if (sock_type_connectible(sk->sk_type))
                        vsock_remove_sock(vsk);
 
                sock_orphan(sk);
@@ -945,7 +958,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
        sk = sock->sk;
        if (sock->state == SS_UNCONNECTED) {
                err = -ENOTCONN;
-               if (sk->sk_type == SOCK_STREAM)
+               if (sock_type_connectible(sk->sk_type))
                        return err;
        } else {
                sock->state = SS_DISCONNECTING;
@@ -960,7 +973,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
                sk->sk_state_change(sk);
                release_sock(sk);
 
-               if (sk->sk_type == SOCK_STREAM) {
+               if (sock_type_connectible(sk->sk_type)) {
                        sock_reset_flag(sk, SOCK_DONE);
                        vsock_send_shutdown(sk, mode);
                }
@@ -1013,7 +1026,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                if (!(sk->sk_shutdown & SEND_SHUTDOWN))
                        mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
 
-       } else if (sock->type == SOCK_STREAM) {
+       } else if (sock_type_connectible(sk->sk_type)) {
                const struct vsock_transport *transport = vsk->transport;
                lock_sock(sk);
 
@@ -1259,8 +1272,8 @@ static void vsock_connect_timeout(struct work_struct *work)
        sock_put(sk);
 }
 
-static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
-                               int addr_len, int flags)
+static int vsock_connect(struct socket *sock, struct sockaddr *addr,
+                        int addr_len, int flags)
 {
        int err;
        struct sock *sk;
@@ -1410,7 +1423,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
 
        lock_sock(listener);
 
-       if (sock->type != SOCK_STREAM) {
+       if (!sock_type_connectible(sock->type)) {
                err = -EOPNOTSUPP;
                goto out;
        }
@@ -1477,6 +1490,19 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
        return err;
 }
 
+/* Per-socket-type connect() entry points; both delegate to vsock_connect(). */
+static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
+                               int addr_len, int flags)
+{
+       return vsock_connect(sock, addr, addr_len, flags);
+}
+
+static int vsock_seqpacket_connect(struct socket *sock, struct sockaddr *addr,
+                                  int addr_len, int flags)
+{
+       return vsock_connect(sock, addr, addr_len, flags);
+}
+
 static int vsock_listen(struct socket *sock, int backlog)
 {
        int err;
@@ -1487,7 +1512,7 @@ static int vsock_listen(struct socket *sock, int backlog)
 
        lock_sock(sk);
 
-       if (sock->type != SOCK_STREAM) {
+       if (!sock_type_connectible(sk->sk_type)) {
                err = -EOPNOTSUPP;
                goto out;
        }
@@ -1531,11 +1556,11 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk,
        vsk->buffer_size = val;
 }
 
-static int vsock_stream_setsockopt(struct socket *sock,
-                                  int level,
-                                  int optname,
-                                  sockptr_t optval,
-                                  unsigned int optlen)
+static int vsock_setsockopt(struct socket *sock,
+                           int level,
+                           int optname,
+                           sockptr_t optval,
+                           unsigned int optlen)
 {
        int err;
        struct sock *sk;
@@ -1612,6 +1637,24 @@ static int vsock_stream_setsockopt(struct socket *sock,
        return err;
 }
 
+/* Per-socket-type setsockopt() entry points. Both wrappers delegate to
+ * vsock_setsockopt(), so the stream and seqpacket ops share a single
+ * option-handling implementation.
+ */
+static int vsock_seqpacket_setsockopt(struct socket *sock, int level,
+                                     int optname, sockptr_t optval,
+                                     unsigned int optlen)
+{
+       return vsock_setsockopt(sock, level, optname, optval, optlen);
+}
+
+static int vsock_stream_setsockopt(struct socket *sock, int level,
+                                  int optname, sockptr_t optval,
+                                  unsigned int optlen)
+{
+       return vsock_setsockopt(sock, level, optname, optval, optlen);
+}
+
 static int vsock_stream_getsockopt(struct socket *sock,
                                   int level, int optname,
                                   char __user *optval,
@@ -1683,6 +1726,14 @@ static int vsock_stream_getsockopt(struct socket *sock,
        return 0;
 }
 
+/* getsockopt() for SOCK_SEQPACKET: delegates to vsock_stream_getsockopt(). */
+static int vsock_seqpacket_getsockopt(struct socket *sock, int level,
+                                     int optname, char __user *optval,
+                                     int __user *optlen)
+{
+       return vsock_stream_getsockopt(sock, level, optname, optval, optlen);
+}
+
 static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
                                     size_t len)
 {
@@ -2209,6 +2260,28 @@ static const struct proto_ops vsock_stream_ops = {
        .sendpage = sock_no_sendpage,
 };
 
+/* proto_ops installed by vsock_create() for SOCK_SEQPACKET sockets. */
+static const struct proto_ops vsock_seqpacket_ops = {
+       .family = PF_VSOCK,
+       .owner = THIS_MODULE,
+       .release = vsock_release,
+       .bind = vsock_bind,
+       .connect = vsock_seqpacket_connect,
+       .socketpair = sock_no_socketpair,
+       .accept = vsock_accept,
+       .getname = vsock_getname,
+       .poll = vsock_poll,
+       .ioctl = sock_no_ioctl,
+       .listen = vsock_listen,
+       .shutdown = vsock_shutdown,
+       .setsockopt = vsock_seqpacket_setsockopt,
+       .getsockopt = vsock_seqpacket_getsockopt,
+       .sendmsg = vsock_seqpacket_sendmsg,
+       .recvmsg = vsock_seqpacket_recvmsg,
+       .mmap = sock_no_mmap,
+       .sendpage = sock_no_sendpage,
+};
+
 static int vsock_create(struct net *net, struct socket *sock,
                        int protocol, int kern)
 {
@@ -2229,6 +2301,9 @@ static int vsock_create(struct net *net, struct socket *sock,
        case SOCK_STREAM:
                sock->ops = &vsock_stream_ops;
                break;
+       case SOCK_SEQPACKET:
+               sock->ops = &vsock_seqpacket_ops;
+               break;
        default:
                return -ESOCKTNOSUPPORT;
        }
-- 
2.25.1

Reply via email to