On Sun, Jan 03, 2021 at 10:57:50PM +0300, Arseny Krasnov wrote:
        This extends rx loop for SOCK_SEQPACKET packets and implements
callback which user calls to copy data to its buffer.

Please write a more detailed commit message that explains the changes, e.g. that the 'flags' field is used to carry the record length when 'op' ==
VIRTIO_VSOCK_OP_SEQ_BEGIN.


Signed-off-by: Arseny Krasnov <arseny.kras...@kaspersky.com>
---
include/linux/virtio_vsock.h            |   7 +
include/net/af_vsock.h                  |   4 +
include/uapi/linux/virtio_vsock.h       |   9 +
net/vmw_vsock/virtio_transport.c        |   3 +
net/vmw_vsock/virtio_transport_common.c | 323 +++++++++++++++++++++---
5 files changed, 305 insertions(+), 41 deletions(-)

diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index dc636b727179..4902d71b3252 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -36,6 +36,10 @@ struct virtio_vsock_sock {
        u32 rx_bytes;
        u32 buf_alloc;
        struct list_head rx_queue;
+
+       /* For SOCK_SEQPACKET */
+       u32 user_read_seq_len;
+       u32 user_read_copied;
};

struct virtio_vsock_pkt {
@@ -80,6 +84,9 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags);

+bool virtio_transport_seqpacket_seq_send_len(struct vsock_sock *vsk, size_t 
len);
+size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk);
+
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);

diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index b1c717286993..792ea7b66574 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -135,6 +135,10 @@ struct vsock_transport {
        bool (*stream_is_active)(struct vsock_sock *);
        bool (*stream_allow)(u32 cid, u32 port);

+       /* SEQ_PACKET. */
+       bool (*seqpacket_seq_send_len)(struct vsock_sock *, size_t len);
+       size_t (*seqpacket_seq_get_len)(struct vsock_sock *);
+

These changes are related to the vsock core, so please move them to a separate patch.

        /* Notification. */
        int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
        int (*notify_poll_out)(struct vsock_sock *, size_t, bool *);
diff --git a/include/uapi/linux/virtio_vsock.h 
b/include/uapi/linux/virtio_vsock.h
index 1d57ed3d84d2..058908bc19fc 100644
--- a/include/uapi/linux/virtio_vsock.h
+++ b/include/uapi/linux/virtio_vsock.h
@@ -65,6 +65,7 @@ struct virtio_vsock_hdr {

enum virtio_vsock_type {
        VIRTIO_VSOCK_TYPE_STREAM = 1,
+       VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
};

enum virtio_vsock_op {
@@ -83,6 +84,9 @@ enum virtio_vsock_op {
        VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6,
        /* Request the peer to send the credit info to us */
        VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7,
+
+       /* Record begin for SOCK_SEQPACKET */
+       VIRTIO_VSOCK_OP_SEQ_BEGIN = 8,
};

/* VIRTIO_VSOCK_OP_SHUTDOWN flags values */
@@ -91,4 +95,9 @@ enum virtio_vsock_shutdown {
        VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
};

+/* VIRTIO_VSOCK_OP_RW flags values for SOCK_SEQPACKET type */
+enum virtio_vsock_rw_seqpacket {
+       VIRTIO_VSOCK_RW_EOR = 1,
+};
+
#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 2700a63ab095..2bd3f7cbffcb 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -469,6 +469,9 @@ static struct virtio_transport virtio_transport = {
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

+               .seqpacket_seq_send_len   = 
virtio_transport_seqpacket_seq_send_len,
+               .seqpacket_seq_get_len    = 
virtio_transport_seqpacket_seq_get_len,
+
                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
diff --git a/net/vmw_vsock/virtio_transport_common.c 
b/net/vmw_vsock/virtio_transport_common.c
index 5956939eebb7..77c42004e422 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -139,6 +139,7 @@ static struct sk_buff *virtio_transport_build_skb(void 
*opaque)
                break;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
+       case VIRTIO_VSOCK_OP_SEQ_BEGIN:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
                break;
        default:
@@ -157,6 +158,10 @@ static struct sk_buff *virtio_transport_build_skb(void 
*opaque)

void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
+       /* TODO: implement tap support for SOCK_SEQPACKET. */
+       if (le32_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_SEQPACKET)
+               return;
+
        if (pkt->tap_delivered)
                return;

@@ -230,10 +235,10 @@ static bool virtio_transport_inc_rx_pkt(struct 
virtio_vsock_sock *vvs,
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
-                                       struct virtio_vsock_pkt *pkt)
+                                       u32 len)
{
-       vvs->rx_bytes -= pkt->len;
-       vvs->fwd_cnt += pkt->len;
+       vvs->rx_bytes -= len;
+       vvs->fwd_cnt += len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct 
virtio_vsock_pkt *pkt)
@@ -365,7 +370,7 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
                total += bytes;
                pkt->off += bytes;
                if (pkt->off == pkt->len) {
-                       virtio_transport_dec_rx_pkt(vvs, pkt);
+                       virtio_transport_dec_rx_pkt(vvs, pkt->len);
                        list_del(&pkt->list);
                        virtio_transport_free_pkt(pkt);
                }
@@ -397,15 +402,202 @@ virtio_transport_stream_do_dequeue(struct vsock_sock 
*vsk,
        return err;
}

+static u16 virtio_transport_get_type(struct sock *sk)
+{
+       if (sk->sk_type == SOCK_STREAM)
+               return VIRTIO_VSOCK_TYPE_STREAM;
+       else
+               return VIRTIO_VSOCK_TYPE_SEQPACKET;
+}
+
+bool virtio_transport_seqpacket_seq_send_len(struct vsock_sock *vsk, size_t 
len)
+{
+       struct virtio_vsock_pkt_info info = {
+               .type = VIRTIO_VSOCK_TYPE_SEQPACKET,
+               .op = VIRTIO_VSOCK_OP_SEQ_BEGIN,
+               .vsk = vsk,
+               .flags = len
+       };
+
+       return virtio_transport_send_pkt_info(vsk, &info);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_seq_send_len);
+
+static inline void virtio_transport_del_n_free_pkt(struct virtio_vsock_pkt 
*pkt)
+{
+       list_del(&pkt->list);
+       virtio_transport_free_pkt(pkt);
+}
+
+static size_t virtio_transport_drop_until_seq_begin(struct virtio_vsock_sock 
*vvs)
+{
+       struct virtio_vsock_pkt *pkt, *n;
+       size_t bytes_dropped = 0;
+
+       list_for_each_entry_safe(pkt, n, &vvs->rx_queue, list) {
+               if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_SEQ_BEGIN)
+                       break;
+
+               bytes_dropped += le32_to_cpu(pkt->hdr.len);
+               virtio_transport_dec_rx_pkt(vvs, pkt->len);
+               virtio_transport_del_n_free_pkt(pkt);
+       }
+
+       return bytes_dropped;
+}
+
+size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk)
+{
+       struct virtio_vsock_sock *vvs = vsk->trans;
+       struct virtio_vsock_pkt *pkt;
+       size_t bytes_dropped;
+
+       spin_lock_bh(&vvs->rx_lock);
+
+       /* Fetch all orphaned 'RW', packets, and
+        * send credit update.
+        */
+       bytes_dropped = virtio_transport_drop_until_seq_begin(vvs);
+
+       if (list_empty(&vvs->rx_queue))
+               goto out;
+
+       pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
+
+       vvs->user_read_copied = 0;
+       vvs->user_read_seq_len = le32_to_cpu(pkt->hdr.flags);
+       virtio_transport_del_n_free_pkt(pkt);
+out:
+       spin_unlock_bh(&vvs->rx_lock);
+
+       if (bytes_dropped)
+               virtio_transport_send_credit_update(vsk,
+                                                   VIRTIO_VSOCK_TYPE_SEQPACKET,
+                                                   NULL);
+
+       return vvs->user_read_seq_len;
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_seq_get_len);
+
+static ssize_t virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
+                                                    struct msghdr *msg,
+                                                    size_t user_buf_len)
+{
+       struct virtio_vsock_sock *vvs = vsk->trans;
+       struct virtio_vsock_pkt *pkt;
+       size_t bytes_handled = 0;
+       int err = 0;
+
+       spin_lock_bh(&vvs->rx_lock);
+
+       if (user_buf_len == 0) {
+               /* User's buffer is full, we processing rest of
+                * record and drop it. If 'SEQ_BEGIN' is found
+                * while iterating, user will be woken up,
+                * because record is already copied, and we
+                * don't care about absent of some tail RW packets
+                * of it. Return number of bytes(rest of record),
+                * but ignore credit update for such absent bytes.
+                */
+               bytes_handled = virtio_transport_drop_until_seq_begin(vvs);
+               vvs->user_read_copied += bytes_handled;
+
+               if (!list_empty(&vvs->rx_queue) &&
+                   vvs->user_read_copied < vvs->user_read_seq_len) {
+                       /* 'SEQ_BEGIN' found, but record isn't complete.
+                        * Set number of copied bytes to fit record size
+                        * and force counters to finish receiving.
+                        */
+                       bytes_handled += (vvs->user_read_seq_len - 
vvs->user_read_copied);
+                       vvs->user_read_copied = vvs->user_read_seq_len;
+               }
+       }
+
+       /* Now start copying. */
+       while (vvs->user_read_copied < vvs->user_read_seq_len &&
+              vvs->rx_bytes &&
+              user_buf_len &&
+              !err) {
+               pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, 
list);
+
+               switch (le16_to_cpu(pkt->hdr.op)) {
+               case VIRTIO_VSOCK_OP_SEQ_BEGIN: {
+                       /* Unexpected 'SEQ_BEGIN' during record copy:
+                        * Leave receive loop, 'EAGAIN' will restart it from
+                        * outer receive loop, packet is still in queue and
+                        * counters are cleared. So in next loop enter,
+                        * 'SEQ_BEGIN' will be dequeued first. User's iov
+                        * iterator will be reset in outer loop. Also
+                        * send credit update, because some bytes could be
+                        * copied. User will never see unfinished record.
+                        */
+                       err = -EAGAIN;
+                       break;
+               }
+               case VIRTIO_VSOCK_OP_RW: {
+                       size_t bytes_to_copy;
+                       size_t pkt_len;
+
+                       pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
+                       bytes_to_copy = min(user_buf_len, pkt_len);
+
+                       /* sk_lock is held by caller so no one else can dequeue.
+                        * Unlock rx_lock since memcpy_to_msg() may sleep.
+                        */
+                       spin_unlock_bh(&vvs->rx_lock);
+
+                       if (memcpy_to_msg(msg, pkt->buf, bytes_to_copy)) {
+                               spin_lock_bh(&vvs->rx_lock);
+                               err = -EINVAL;
+                               break;
+                       }
+
+                       spin_lock_bh(&vvs->rx_lock);
+                       user_buf_len -= bytes_to_copy;
+                       bytes_handled += pkt->len;
+                       vvs->user_read_copied += bytes_to_copy;
+
+                       if (le16_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_RW_EOR)
+                               msg->msg_flags |= MSG_EOR;
+                       break;
+               }
+               default:
+                       ;
+               }
+
+               /* For unexpected 'SEQ_BEGIN', keep such packet in queue,
+                * but drop any other type of packet.
+                */
+               if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_SEQ_BEGIN) {
+                       virtio_transport_dec_rx_pkt(vvs, pkt->len);
+                       virtio_transport_del_n_free_pkt(pkt);
+               }
+       }
+
+       spin_unlock_bh(&vvs->rx_lock);
+
+       virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_SEQPACKET,
+                                           NULL);
+
+       return err ?: bytes_handled;
+}
+
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len, int flags)
{
-       if (flags & MSG_PEEK)
-               return virtio_transport_stream_do_peek(vsk, msg, len);
-       else
+       if (virtio_transport_get_type(sk_vsock(vsk)) == 
VIRTIO_VSOCK_TYPE_SEQPACKET) {
+               if (flags & MSG_PEEK)
+                       return -EOPNOTSUPP;
+
+               return virtio_transport_seqpacket_do_dequeue(vsk, msg, len);
+       } else {
+               if (flags & MSG_PEEK)
+ return virtio_transport_stream_do_peek(vsk, msg, len);
+
                return virtio_transport_stream_do_dequeue(vsk, msg, len);
+       }

Maybe it is better to create two different functions here, since we are using two completely different paths:
virtio_transport_stream_dequeue()
virtio_transport_seqpacket_dequeue()

This would also mean adding a new 'seqpacket_dequeue' callback to the transport.

}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

@@ -481,6 +673,8 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk,
        spin_lock_init(&vvs->rx_lock);
        spin_lock_init(&vvs->tx_lock);
        INIT_LIST_HEAD(&vvs->rx_queue);
+       vvs->user_read_copied = 0;
+       vvs->user_read_seq_len = 0;

        return 0;
}
@@ -490,13 +684,16 @@ EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
+       int type;

        if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                *val = VIRTIO_VSOCK_MAX_BUF_SIZE;

        vvs->buf_alloc = *val;

-       virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
+       type = virtio_transport_get_type(sk_vsock(vsk));
+
+       virtio_transport_send_credit_update(vsk, type,
                                            NULL);

This line can now be merged with the previous one.

}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
@@ -624,10 +821,11 @@ int virtio_transport_connect(struct vsock_sock *vsk)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
-               .type = VIRTIO_VSOCK_TYPE_STREAM,
                .vsk = vsk,
        };

+       info.type = virtio_transport_get_type(sk_vsock(vsk));
+
        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
@@ -636,7 +834,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int 
mode)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
-               .type = VIRTIO_VSOCK_TYPE_STREAM,
                .flags = (mode & RCV_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
@@ -644,6 +841,8 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int 
mode)
                .vsk = vsk,
        };

+       info.type = virtio_transport_get_type(sk_vsock(vsk));
+
        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
@@ -665,12 +864,18 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RW,
-               .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
                .vsk = vsk,
+               .flags = 0,
        };

+       info.type = virtio_transport_get_type(sk_vsock(vsk));
+
+       if (info.type == VIRTIO_VSOCK_TYPE_SEQPACKET &&
+           msg->msg_flags & MSG_EOR)
+               info.flags |= VIRTIO_VSOCK_RW_EOR;
+
        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
@@ -688,7 +893,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
-               .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
                .vsk = vsk,
        };
@@ -697,6 +901,8 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

+       info.type = virtio_transport_get_type(sk_vsock(vsk));
+
        return virtio_transport_send_pkt_info(vsk, &info);
}

@@ -884,44 +1090,59 @@ virtio_transport_recv_connecting(struct sock *sk,
        return err;
}

-static void
+static bool
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
                              struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
-       bool can_enqueue, free_pkt = false;
+       bool data_ready = false;
+       bool free_pkt = false;

-       pkt->len = le32_to_cpu(pkt->hdr.len);
        pkt->off = 0;
+       pkt->len = le32_to_cpu(pkt->hdr.len);

        spin_lock_bh(&vvs->rx_lock);

-       can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
-       if (!can_enqueue) {
-               free_pkt = true;
-               goto out;
-       }
+       switch (le32_to_cpu(pkt->hdr.type)) {
+       case VIRTIO_VSOCK_TYPE_STREAM: {
+               if (!virtio_transport_inc_rx_pkt(vvs, pkt)) {
+                       free_pkt = true;
+                       goto out;
+               }

-       /* Try to copy small packets into the buffer of last packet queued,
-        * to avoid wasting memory queueing the entire buffer with a small
-        * payload.
-        */
-       if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
-               struct virtio_vsock_pkt *last_pkt;
+               /* Try to copy small packets into the buffer of last packet 
queued,
+                * to avoid wasting memory queueing the entire buffer with a 
small
+                * payload.
+                */
+               if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
+                       struct virtio_vsock_pkt *last_pkt;

-               last_pkt = list_last_entry(&vvs->rx_queue,
-                                          struct virtio_vsock_pkt, list);
+                       last_pkt = list_last_entry(&vvs->rx_queue,
+                                                  struct virtio_vsock_pkt, 
list);

-               /* If there is space in the last packet queued, we copy the
-                * new packet in its buffer.
-                */
-               if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
-                       memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
-                              pkt->len);
-                       last_pkt->len += pkt->len;
-                       free_pkt = true;
-                       goto out;
+                       /* If there is space in the last packet queued, we copy 
the
+                        * new packet in its buffer.
+                        */
+                       if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
+                               memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
+                                      pkt->len);
+                               last_pkt->len += pkt->len;
+                               free_pkt = true;
+                               goto out;
+                       }
                }
+
+               data_ready = true;
+               break;
+       }
+
  ^
Unnecessary line break.

+       case VIRTIO_VSOCK_TYPE_SEQPACKET: {
+               data_ready = true;
+               vvs->rx_bytes += pkt->len;

Why not use virtio_transport_inc_rx_pkt() here?

+               break;
+       }
+       default:
+               goto out;
        }

        list_add_tail(&pkt->list, &vvs->rx_queue);
@@ -930,6 +1151,8 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
        spin_unlock_bh(&vvs->rx_lock);
        if (free_pkt)
                virtio_transport_free_pkt(pkt);
+
+       return data_ready;
}

static int
@@ -940,9 +1163,17 @@ virtio_transport_recv_connected(struct sock *sk,
        int err = 0;

        switch (le16_to_cpu(pkt->hdr.op)) {
+       case VIRTIO_VSOCK_OP_SEQ_BEGIN: {
+               struct virtio_vsock_sock *vvs = vsk->trans;
+
+               spin_lock_bh(&vvs->rx_lock);
+               list_add_tail(&pkt->list, &vvs->rx_queue);
+               spin_unlock_bh(&vvs->rx_lock);
+               return err;
+       }
        case VIRTIO_VSOCK_OP_RW:
-               virtio_transport_recv_enqueue(vsk, pkt);
-               sk->sk_data_ready(sk);
+               if (virtio_transport_recv_enqueue(vsk, pkt))
+                       sk->sk_data_ready(sk);
                return err;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
@@ -990,13 +1221,14 @@ virtio_transport_send_response(struct vsock_sock *vsk,
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RESPONSE,
-               .type = VIRTIO_VSOCK_TYPE_STREAM,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
                .vsk = vsk,
        };

+       info.type = virtio_transport_get_type(sk_vsock(vsk));
+
        return virtio_transport_send_pkt_info(vsk, &info);
}

@@ -1086,6 +1318,12 @@ virtio_transport_recv_listen(struct sock *sk, struct 
virtio_vsock_pkt *pkt,
        return 0;
}

+static bool virtio_transport_valid_type(u16 type)
+{
+       return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
+              (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
+}
+
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
@@ -1111,7 +1349,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
                                        le32_to_cpu(pkt->hdr.buf_alloc),
                                        le32_to_cpu(pkt->hdr.fwd_cnt));

-       if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
+       if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
                (void)virtio_transport_reset_no_sock(t, pkt);
                goto free_pkt;
        }
@@ -1128,6 +1366,9 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
                }
        }

+       if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type))
+               goto free_pkt;
+
        vsk = vsock_sk(sk);

        space_available = virtio_transport_space_update(sk, pkt);
--
2.25.1


Reply via email to