Add tests that do a MSG_PEEK recv followed by a regular receive to
test flag support.

Signed-off-by: John Fastabend <john.fastab...@gmail.com>
---
 tools/testing/selftests/bpf/test_sockmap.c |  167 +++++++++++++++++++---------
 1 file changed, 115 insertions(+), 52 deletions(-)

diff --git a/tools/testing/selftests/bpf/test_sockmap.c b/tools/testing/selftests/bpf/test_sockmap.c
index 7cb69ce..cbd1c0b 100644
--- a/tools/testing/selftests/bpf/test_sockmap.c
+++ b/tools/testing/selftests/bpf/test_sockmap.c
@@ -80,6 +80,7 @@
 int txmsg_ingress;
 int txmsg_skb;
 int ktls;
+int peek_flag;
 
 static const struct option long_options[] = {
        {"help",        no_argument,            NULL, 'h' },
@@ -102,6 +103,7 @@
        {"txmsg_ingress", no_argument,          &txmsg_ingress, 1 },
        {"txmsg_skb", no_argument,              &txmsg_skb, 1 },
        {"ktls", no_argument,                   &ktls, 1 },
+       {"peek", no_argument,                   &peek_flag, 1 },
        {0, 0, NULL, 0 }
 };
 
@@ -352,33 +354,40 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
        return 0;
 }
 
-static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
-                   struct msg_stats *s, bool tx,
-                   struct sockmap_options *opt)
+static void msg_free_iov(struct msghdr *msg)
 {
-       struct msghdr msg = {0};
-       int err, i, flags = MSG_NOSIGNAL;
+       int i;
+
+       for (i = 0; i < msg->msg_iovlen; i++)
+               free(msg->msg_iov[i].iov_base);
+       free(msg->msg_iov);
+       msg->msg_iov = NULL;
+       msg->msg_iovlen = 0;
+}
+
+static int msg_alloc_iov(struct msghdr *msg,
+                        int iov_count, int iov_length,
+                        bool data, bool xmit)
+{
+       unsigned char k = 0;
        struct iovec *iov;
-       unsigned char k;
-       bool data_test = opt->data_test;
-       bool drop = opt->drop_expected;
+       int i;
 
        iov = calloc(iov_count, sizeof(struct iovec));
        if (!iov)
                return errno;
 
-       k = 0;
        for (i = 0; i < iov_count; i++) {
                unsigned char *d = calloc(iov_length, sizeof(char));
 
                if (!d) {
                        fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
-                       goto out_errno;
+                       goto unwind_iov;
                }
                iov[i].iov_base = d;
                iov[i].iov_len = iov_length;
 
-               if (data_test && tx) {
+               if (data && xmit) {
                        int j;
 
                        for (j = 0; j < iov_length; j++)
@@ -386,9 +395,60 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
                }
        }
 
-       msg.msg_iov = iov;
-       msg.msg_iovlen = iov_count;
-       k = 0;
+       msg->msg_iov = iov;
+       msg->msg_iovlen = iov_count;
+
+       return 0;
+unwind_iov:
+       for (i--; i >= 0 ; i--)
+               free(msg->msg_iov[i].iov_base);
+       return -ENOMEM;
+}
+
+static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz)
+{
+       int i, j, bytes_cnt = 0;
+       unsigned char k = 0;
+
+       for (i = 0; i < msg->msg_iovlen; i++) {
+               unsigned char *d = msg->msg_iov[i].iov_base;
+
+               for (j = 0;
+                    j < msg->msg_iov[i].iov_len && size; j++) {
+                       if (d[j] != k++) {
+                               fprintf(stderr,
+                                       "detected data corruption @iov[%i]:%i 
%02x != %02x, %02x ?= %02x\n",
+                                       i, j, d[j], k - 1, d[j+1], k);
+                               return -EIO;
+                       }
+                       bytes_cnt++;
+                       if (bytes_cnt == chunk_sz) {
+                               k = 0;
+                               bytes_cnt = 0;
+                       }
+                       size--;
+               }
+       }
+       return 0;
+}
+
+static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
+                   struct msg_stats *s, bool tx,
+                   struct sockmap_options *opt)
+{
+       struct msghdr msg = {0}, msg_peek = {0};
+       int err, i, flags = MSG_NOSIGNAL;
+       bool drop = opt->drop_expected;
+       bool data = opt->data_test;
+
+       err = msg_alloc_iov(&msg, iov_count, iov_length, data, tx);
+       if (err)
+               goto out_errno;
+       if (peek_flag) {
+               err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
+               if (err)
+                       goto out_errno;
+       }
 
        if (tx) {
                clock_gettime(CLOCK_MONOTONIC, &s->start);
@@ -408,19 +468,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
                }
                clock_gettime(CLOCK_MONOTONIC, &s->end);
        } else {
-               int slct, recv, max_fd = fd;
+               int slct, recvp = 0, recv, max_fd = fd;
                int fd_flags = O_NONBLOCK;
                struct timeval timeout;
                float total_bytes;
-               int bytes_cnt = 0;
-               int chunk_sz;
                fd_set w;
 
-               if (opt->sendpage)
-                       chunk_sz = iov_length * cnt;
-               else
-                       chunk_sz = iov_length * iov_count;
-
                fcntl(fd, fd_flags);
                total_bytes = (float)iov_count * (float)iov_length * (float)cnt;
                err = clock_gettime(CLOCK_MONOTONIC, &s->start);
@@ -452,6 +505,19 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
                                goto out_errno;
                        }
 
+                       errno = 0;
+                       if (peek_flag) {
+                               flags |= MSG_PEEK;
+                               recvp = recvmsg(fd, &msg_peek, flags);
+                               if (recvp < 0) {
+                                       if (errno != EWOULDBLOCK) {
+                                               clock_gettime(CLOCK_MONOTONIC, 
&s->end);
+                                               goto out_errno;
+                                       }
+                               }
+                               flags = 0;
+                       }
+
                        recv = recvmsg(fd, &msg, flags);
                        if (recv < 0) {
                                if (errno != EWOULDBLOCK) {
@@ -463,27 +529,23 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
 
                        s->bytes_recvd += recv;
 
-                       if (data_test) {
-                               int j;
-
-                               for (i = 0; i < msg.msg_iovlen; i++) {
-                                       unsigned char *d = iov[i].iov_base;
-
-                                       for (j = 0;
-                                            j < iov[i].iov_len && recv; j++) {
-                                               if (d[j] != k++) {
-                                                       errno = -EIO;
-                                                       fprintf(stderr,
-                                                               "detected data 
corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
-                                                               i, j, d[j], k - 
1, d[j+1], k);
-                                                       goto out_errno;
-                                               }
-                                               bytes_cnt++;
-                                               if (bytes_cnt == chunk_sz) {
-                                                       k = 0;
-                                                       bytes_cnt = 0;
-                                               }
-                                               recv--;
+                       if (data) {
+                               int chunk_sz = opt->sendpage ?
+                                               iov_length * cnt :
+                                               iov_length * iov_count;
+
+                               errno = msg_verify_data(&msg, recv, chunk_sz);
+                               if (errno) {
+                                       perror("data verify msg failed\n");
+                                       goto out_errno;
+                               }
+                               if (recvp) {
+                                       errno = msg_verify_data(&msg_peek,
+                                                               recvp,
+                                                               chunk_sz);
+                                       if (errno) {
+                                               perror("data verify msg_peek 
failed\n");
+                                               goto out_errno;
                                        }
                                }
                        }
@@ -491,14 +553,12 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
                clock_gettime(CLOCK_MONOTONIC, &s->end);
        }
 
-       for (i = 0; i < iov_count; i++)
-               free(iov[i].iov_base);
-       free(iov);
-       return 0;
+       msg_free_iov(&msg);
+       msg_free_iov(&msg_peek);
+       return err;
 out_errno:
-       for (i = 0; i < iov_count; i++)
-               free(iov[i].iov_base);
-       free(iov);
+       msg_free_iov(&msg);
+       msg_free_iov(&msg_peek);
        return errno;
 }
 
@@ -565,9 +625,10 @@ static int sendmsg_test(struct sockmap_options *opt)
                }
                if (opt->verbose)
                        fprintf(stdout,
-                               "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB 
%fB/s %fGB/s\n",
+                               "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB 
%fB/s %fGB/s %s\n",
                                s.bytes_sent, sent_Bps, sent_Bps/giga,
-                               s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
+                               s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
+                               peek_flag ? "(peek_msg)" : "");
                if (err && txmsg_cork)
                        err = 0;
                exit(err ? 1 : 0);
@@ -999,6 +1060,8 @@ static void test_options(char *options)
                strncat(options, "skb,", OPTSTRING);
        if (ktls)
                strncat(options, "ktls,", OPTSTRING);
+       if (peek_flag)
+               strncat(options, "peek,", OPTSTRING);
 }
 
 static int __test_exec(int cgrp, int test, struct sockmap_options *opt)

Reply via email to