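In tcp_recvmsg_locked(), detect if the skb being received by the user
is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
flag - pass it to tcp_recvmsg_dmabuf() for custom handling. Without the
flag, the receive stops at the devmem skb (returning -EFAULT if nothing
was copied yet).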

tcp_recvmsg_dmabuf() copies any data in the skb header (linear area) to
the user's linear buffer, and returns an SO_DEVMEM_LINEAR cmsg to the user
indicating the number of bytes copied into the linear buffer.

tcp_recvmsg_dmabuf() then loops over the unreadable devmem skb frags,
and returns to the user one SO_DEVMEM_DMABUF cmsg (a struct dmabuf_cmsg)
for each frag, or part of a frag, received. Each cmsg indicates the
location of the data in the dmabuf device memory. struct dmabuf_cmsg
contains this information:

1. the offset into the dmabuf where the payload starts. 'frag_offset'.
2. the size of the frag. 'frag_size'.
3. an opaque token 'frag_token' to return to the kernel when the buffer
is to be released.
4. the id of the dmabuf the frag belongs to. 'dmabuf_id'.

The net_iovs handed to userspace are stored in the newly added
sk->sk_user_frags xarray. The page pool reference count (pp_ref_count)
of each net_iov passed to userspace is incremented. This reference is
dropped once userspace indicates that it is done reading the frag. All
outstanding references are released when the socket is destroyed.

Signed-off-by: Willem de Bruijn <will...@google.com>
Signed-off-by: Kaiyuan Zhang <kaiyu...@google.com>
Signed-off-by: Mina Almasry <almasrym...@google.com>
Reviewed-by: Pavel Begunkov <asml.sile...@gmail.com>
Reviewed-by: Eric Dumazet <eduma...@google.com>

---

v20:
- Change `offset = 0` to `offset = offset - start` to resolve issue
  reported by Taehee.

v16:
- Fix number assignment (Arnd).

v13:
- Refactored user frags cleanup into a common function to avoid
  __maybe_unused. (Pavel)
- change to offset = 0 for some improved clarity.

v11:
- Refactor to common function to remove conditional lock sparse warning
  (Paolo)

v7:
- Updated the SO_DEVMEM_* uapi to use the next available entries (Arnd).
- Updated dmabuf_cmsg struct to be __u64 padded (Arnd).
- Squashed fix from Eric to initialize sk_user_frags for passive
  sockets (Eric).

v6
- skb->dmabuf -> skb->readable (Pavel)
- Fixed asm definitions of SO_DEVMEM_LINEAR/SO_DEVMEM_DMABUF not found
  on some archs.
- Squashed in locking optimizations from eduma...@google.com. With this
  change we lock the xarray once per tcp_recvmsg_dmabuf() call rather
  than once per frag in xa_alloc().

Changes in v1:
- Added dmabuf_id to dmabuf_cmsg (David/Stan).
- Devmem -> dmabuf (David).
- Change tcp_recvmsg_dmabuf() check to skb->dmabuf (Paolo).
- Use __skb_frag_ref() & napi_pp_put_page() for refcounting (Yunsheng).

RFC v3:
- Fixed issue with put_cmsg() failing silently.

---
 arch/alpha/include/uapi/asm/socket.h  |   5 +
 arch/mips/include/uapi/asm/socket.h   |   5 +
 arch/parisc/include/uapi/asm/socket.h |   5 +
 arch/sparc/include/uapi/asm/socket.h  |   5 +
 include/linux/socket.h                |   1 +
 include/net/netmem.h                  |  13 ++
 include/net/sock.h                    |   2 +
 include/uapi/asm-generic/socket.h     |   5 +
 include/uapi/linux/uio.h              |  13 ++
 net/ipv4/tcp.c                        | 255 +++++++++++++++++++++++++-
 net/ipv4/tcp_ipv4.c                   |  16 ++
 net/ipv4/tcp_minisocks.c              |   2 +
 12 files changed, 322 insertions(+), 5 deletions(-)

diff --git a/arch/alpha/include/uapi/asm/socket.h 
b/arch/alpha/include/uapi/asm/socket.h
index e94f621903fe..ef4656a41058 100644
--- a/arch/alpha/include/uapi/asm/socket.h
+++ b/arch/alpha/include/uapi/asm/socket.h
@@ -140,6 +140,11 @@
 #define SO_PASSPIDFD           76
 #define SO_PEERPIDFD           77
 
+#define SO_DEVMEM_LINEAR       78
+#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF       79
+#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/mips/include/uapi/asm/socket.h 
b/arch/mips/include/uapi/asm/socket.h
index 60ebaed28a4c..414807d55e33 100644
--- a/arch/mips/include/uapi/asm/socket.h
+++ b/arch/mips/include/uapi/asm/socket.h
@@ -151,6 +151,11 @@
 #define SO_PASSPIDFD           76
 #define SO_PEERPIDFD           77
 
+#define SO_DEVMEM_LINEAR       78
+#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF       79
+#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/parisc/include/uapi/asm/socket.h 
b/arch/parisc/include/uapi/asm/socket.h
index be264c2b1a11..2b817efd4544 100644
--- a/arch/parisc/include/uapi/asm/socket.h
+++ b/arch/parisc/include/uapi/asm/socket.h
@@ -132,6 +132,11 @@
 #define SO_PASSPIDFD           0x404A
 #define SO_PEERPIDFD           0x404B
 
+#define SO_DEVMEM_LINEAR       78
+#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF       79
+#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/sparc/include/uapi/asm/socket.h 
b/arch/sparc/include/uapi/asm/socket.h
index 682da3714686..00248fc68977 100644
--- a/arch/sparc/include/uapi/asm/socket.h
+++ b/arch/sparc/include/uapi/asm/socket.h
@@ -133,6 +133,11 @@
 #define SO_PASSPIDFD             0x0055
 #define SO_PEERPIDFD             0x0056
 
+#define SO_DEVMEM_LINEAR         0x0057
+#define SCM_DEVMEM_LINEAR        SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF         0x0058
+#define SCM_DEVMEM_DMABUF        SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 
diff --git a/include/linux/socket.h b/include/linux/socket.h
index df9cdb8bbfb8..d18cc47e89bd 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -327,6 +327,7 @@ struct ucred {
                                          * plain text and require encryption
                                          */
 
+#define MSG_SOCK_DEVMEM 0x2000000      /* Receive devmem skbs as cmsg */
 #define MSG_ZEROCOPY   0x4000000       /* Use user data in kernel path */
 #define MSG_SPLICE_PAGES 0x8000000     /* Splice the pages from the iterator 
in sendmsg() */
 #define MSG_FASTOPEN   0x20000000      /* Send data in TCP SYN */
diff --git a/include/net/netmem.h b/include/net/netmem.h
index 61400d4b0d66..209013aaea18 100644
--- a/include/net/netmem.h
+++ b/include/net/netmem.h
@@ -66,6 +66,19 @@ static inline unsigned int net_iov_idx(const struct net_iov 
*niov)
        return niov - net_iov_owner(niov)->niovs;
 }
 
+/* Position of @niov in its dmabuf: the owning chunk's base_virtual plus
+ * one PAGE_SIZE slot for every net_iov that precedes it in the chunk.
+ * tcp_recvmsg_dmabuf() uses this as the base of dmabuf_cmsg.frag_offset.
+ */
+static inline unsigned long net_iov_virtual_addr(const struct net_iov *niov)
+{
+       unsigned long page_off;
+
+       page_off = (unsigned long)net_iov_idx(niov) << PAGE_SHIFT;
+       return net_iov_owner(niov)->base_virtual + page_off;
+}
+
+/* Id of the dmabuf binding that owns @niov (reported as dmabuf_id). */
+static inline u32 net_iov_binding_id(const struct net_iov *niov)
+{
+       struct dmabuf_genpool_chunk_owner *owner = net_iov_owner(niov);
+
+       return owner->binding->id;
+}
+
 static inline struct net_devmem_dmabuf_binding *
 net_iov_binding(const struct net_iov *niov)
 {
diff --git a/include/net/sock.h b/include/net/sock.h
index f51d61fab059..c58ca8dd561b 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -337,6 +337,7 @@ struct sk_filter;
   *    @sk_txtime_report_errors: set report errors mode for SO_TXTIME
   *    @sk_txtime_unused: unused txtime flags
   *    @ns_tracker: tracker for netns reference
+  *    @sk_user_frags: xarray of pages the user is holding a reference on.
   */
 struct sock {
        /*
@@ -542,6 +543,7 @@ struct sock {
 #endif
        struct rcu_head         sk_rcu;
        netns_tracker           ns_tracker;
+       struct xarray           sk_user_frags;
 };
 
 struct sock_bh_locked {
diff --git a/include/uapi/asm-generic/socket.h 
b/include/uapi/asm-generic/socket.h
index 8ce8a39a1e5f..e993edc9c0ee 100644
--- a/include/uapi/asm-generic/socket.h
+++ b/include/uapi/asm-generic/socket.h
@@ -135,6 +135,11 @@
 #define SO_PASSPIDFD           76
 #define SO_PEERPIDFD           77
 
+#define SO_DEVMEM_LINEAR       78
+#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF       79
+#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
diff --git a/include/uapi/linux/uio.h b/include/uapi/linux/uio.h
index 059b1a9147f4..3a22ddae376a 100644
--- a/include/uapi/linux/uio.h
+++ b/include/uapi/linux/uio.h
@@ -20,6 +20,19 @@ struct iovec
        __kernel_size_t iov_len; /* Must be size_t (1003.1g) */
 };
 
+/* Payload of the SO_DEVMEM_LINEAR and SO_DEVMEM_DMABUF control messages
+ * returned by recvmsg(MSG_SOCK_DEVMEM) on a devmem TCP socket.
+ *
+ * SO_DEVMEM_LINEAR: only frag_size is set (number of bytes copied into the
+ * regular user buffer); all other fields are zero.
+ * SO_DEVMEM_DMABUF: describes one frag (or part of one) whose payload was
+ * not copied and remains in dmabuf memory.
+ *
+ * This is UAPI: the layout (field order and sizes) must not change.
+ */
+struct dmabuf_cmsg {
+       __u64 frag_offset;      /* offset into the dmabuf where the frag starts.
+                                */
+       __u32 frag_size;        /* size of the frag, in bytes. */
+       __u32 frag_token;       /* token representing this frag for
+                                * DEVMEM_DONTNEED.
+                                */
+       __u32  dmabuf_id;       /* dmabuf id this frag belongs to. */
+       __u32 flags;            /* Currently unused. Reserved for future
+                                * uses.
+                                */
+};
+
 /*
  *     UIO_MAXIOV shall be at least 16 1003.1g (5.4.1.1)
  */
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 30238963fe99..93cb76813bf1 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -471,6 +471,7 @@ void tcp_init_sock(struct sock *sk)
 
        set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
        sk_sockets_allocated_inc(sk);
+       xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1);
 }
 EXPORT_SYMBOL(tcp_init_sock);
 
@@ -2323,6 +2324,220 @@ static int tcp_inq_hint(struct sock *sk)
        return inq;
 }
 
+/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead.
+ *
+ * tokens[0..max) are ids reserved in sk->sk_user_frags by
+ * tcp_xa_pool_refill(), each holding an XA_ZERO_ENTRY placeholder.
+ * tokens[0..idx) have already been reported to user space, and
+ * netmems[i] is the netmem token i refers to. tcp_xa_pool_commit_locked()
+ * stores those netmems and erases the unused reservations.
+ */
+struct tcp_xa_pool {
+       u8              max; /* max <= MAX_SKB_FRAGS */
+       u8              idx; /* idx <= max */
+       __u32           tokens[MAX_SKB_FRAGS];  /* reserved xarray ids */
+       netmem_ref      netmems[MAX_SKB_FRAGS]; /* netmem for tokens[0..idx) */
+};
+
+/* Caller holds the sk->sk_user_frags xa_lock. Publish the netmems whose
+ * tokens were reported to user space, release the reserved but unused
+ * tokens, and leave @p empty.
+ */
+static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p)
+{
+       struct xarray *frags = &sk->sk_user_frags;
+       int slot = 0;
+
+       /* Tokens [0, idx) were given to user space: store their netmems. */
+       while (slot < p->idx) {
+               __xa_cmpxchg(frags, p->tokens[slot], XA_ZERO_ENTRY,
+                            (__force void *)p->netmems[slot], GFP_KERNEL);
+               slot++;
+       }
+
+       /* Tokens [idx, max) were reserved but never used: roll them back. */
+       while (slot < p->max) {
+               __xa_erase(frags, p->tokens[slot]);
+               slot++;
+       }
+
+       p->max = 0;
+       p->idx = 0;
+}
+
+/* Take the sk->sk_user_frags lock and commit @p. Nothing to do (and no
+ * locking) when the pool holds no reservations.
+ */
+static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p)
+{
+       struct xarray *frags = &sk->sk_user_frags;
+
+       if (p->max == 0)
+               return;
+
+       xa_lock_bh(frags);
+       tcp_xa_pool_commit_locked(sk, p);
+       xa_unlock_bh(frags);
+}
+
+/* Make sure @p has at least one unused token.
+ *
+ * When tokens remain (idx < max) nothing is done. Otherwise the pool is
+ * committed and up to @max_frags new ids are reserved in sk->sk_user_frags
+ * (as XA_ZERO_ENTRY placeholders) inside a single xa_lock_bh() section.
+ * @max_frags is clamped to the capacity of the pool arrays.
+ *
+ * Returns 0 if at least one token is available. Otherwise returns the error
+ * from __xa_alloc(), or -EINVAL when @max_frags is 0.
+ */
+static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p,
+                             unsigned int max_frags)
+{
+       /* Returned unchanged only when max_frags == 0: nothing could be
+        * reserved, so report failure rather than an uninitialized value.
+        */
+       int err = -EINVAL;
+       int k;
+
+       if (p->idx < p->max)
+               return 0;
+
+       /* tokens[] and netmems[] hold at most MAX_SKB_FRAGS entries. */
+       if (max_frags > MAX_SKB_FRAGS)
+               max_frags = MAX_SKB_FRAGS;
+
+       xa_lock_bh(&sk->sk_user_frags);
+
+       tcp_xa_pool_commit_locked(sk, p);
+
+       for (k = 0; k < max_frags; k++) {
+               err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k],
+                                XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL);
+               if (err)
+                       break;
+       }
+
+       xa_unlock_bh(&sk->sk_user_frags);
+
+       p->max = k;
+       p->idx = 0;
+       /* A partial refill is fine: the caller needs only one token now. */
+       return k ? 0 : err;
+}
+
+/* Receive up to @remaining_len bytes of a devmem (unreadable-frag) skb,
+ * starting at @offset, for a recvmsg(MSG_SOCK_DEVMEM) caller.
+ *
+ * Bytes in the skb linear area are copied into msg->msg_iter and reported
+ * with an SO_DEVMEM_LINEAR cmsg. Bytes in frags are not copied. Instead,
+ * each frag (or part of one) is reported with an SO_DEVMEM_DMABUF cmsg
+ * carrying its dmabuf offset, size, dmabuf id and a token reserved in
+ * sk->sk_user_frags. The frag's pp_ref_count is raised, and that reference
+ * is tracked in sk->sk_user_frags until user space gives the token back or
+ * the socket is destroyed.
+ *
+ * If @remaining_len is not yet satisfied, traversal continues into
+ * skb_shinfo(skb)->frag_list, then skb->next.
+ *
+ * Returns the number of bytes reported to the user, which may be less than
+ * @remaining_len. If nothing was reported, returns a negative errno:
+ * -ENODEV for a readable skb or a frag that is not a net_iov; -EFAULT if
+ * copying the linear part fails or the skb chain is too short; -ETOOSMALL
+ * if the control buffer cannot hold another cmsg; or the error from
+ * put_cmsg() or token allocation.
+ */
+static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb,
+                             unsigned int offset, struct msghdr *msg,
+                             int remaining_len)
+{
+       struct dmabuf_cmsg dmabuf_cmsg = { 0 };
+       struct tcp_xa_pool tcp_xa_pool;
+       unsigned int start;
+       int i, copy, n;
+       int sent = 0;
+       int err = 0;
+
+       tcp_xa_pool.max = 0;
+       tcp_xa_pool.idx = 0;
+       do {
+               start = skb_headlen(skb);
+
+               if (skb_frags_readable(skb)) {
+                       err = -ENODEV;
+                       goto out;
+               }
+
+               /* Copy header. */
+               copy = start - offset;
+               if (copy > 0) {
+                       copy = min(copy, remaining_len);
+
+                       n = copy_to_iter(skb->data + offset, copy,
+                                        &msg->msg_iter);
+                       if (n != copy) {
+                               err = -EFAULT;
+                               goto out;
+                       }
+
+                       offset += copy;
+                       remaining_len -= copy;
+
+                       /* First a dmabuf_cmsg for # bytes copied to user
+                        * buffer.
+                        */
+                       memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg));
+                       dmabuf_cmsg.frag_size = copy;
+                       err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR,
+                                      sizeof(dmabuf_cmsg), &dmabuf_cmsg);
+                       if (err || msg->msg_flags & MSG_CTRUNC) {
+                               /* A cmsg that did not fit must not be
+                                * silently dropped: clear MSG_CTRUNC and
+                                * report -ETOOSMALL instead.
+                                */
+                               msg->msg_flags &= ~MSG_CTRUNC;
+                               if (!err)
+                                       err = -ETOOSMALL;
+                               goto out;
+                       }
+
+                       sent += copy;
+
+                       if (remaining_len == 0)
+                               goto out;
+               }
+
+               /* after that, send information of dmabuf pages through a
+                * sequence of cmsg
+                */
+               for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+                       skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
+                       struct net_iov *niov;
+                       u64 frag_offset;
+                       int end;
+
+                       /* !skb_frags_readable() should indicate that ALL the
+                        * frags in this skb are dmabuf net_iovs. We're checking
+                        * for that flag above, but also check individual frags
+                        * here. If the tcp stack is not setting
+                        * skb_frags_readable() correctly, we still don't want
+                        * to crash here.
+                        */
+                       if (!skb_frag_net_iov(frag)) {
+                               net_err_ratelimited("Found non-net_iov frag in unreadable skb\n");
+                               err = -ENODEV;
+                               goto out;
+                       }
+
+                       niov = skb_frag_net_iov(frag);
+                       end = start + skb_frag_size(frag);
+                       copy = end - offset;
+
+                       if (copy > 0) {
+                               copy = min(copy, remaining_len);
+
+                               frag_offset = net_iov_virtual_addr(niov) +
+                                             skb_frag_off(frag) + offset -
+                                             start;
+                               dmabuf_cmsg.frag_offset = frag_offset;
+                               dmabuf_cmsg.frag_size = copy;
+                               err = tcp_xa_pool_refill(sk, &tcp_xa_pool,
+                                                        skb_shinfo(skb)->nr_frags - i);
+                               if (err)
+                                       goto out;
+
+                               /* The token's slot still holds XA_ZERO_ENTRY;
+                                * the netmem is stored there by
+                                * tcp_xa_pool_commit() once the cmsg is out.
+                                */
+                               dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx];
+                               dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov);
+
+                               offset += copy;
+                               remaining_len -= copy;
+
+                               err = put_cmsg(msg, SOL_SOCKET,
+                                              SO_DEVMEM_DMABUF,
+                                              sizeof(dmabuf_cmsg),
+                                              &dmabuf_cmsg);
+                               if (err || msg->msg_flags & MSG_CTRUNC) {
+                                       msg->msg_flags &= ~MSG_CTRUNC;
+                                       if (!err)
+                                               err = -ETOOSMALL;
+                                       goto out;
+                               }
+
+                               /* Reference now owned by user space; dropped
+                                * when the token is released or the socket
+                                * is destroyed.
+                                */
+                               atomic_long_inc(&niov->pp_ref_count);
+                               tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag);
+
+                               sent += copy;
+
+                               if (remaining_len == 0)
+                                       goto out;
+                       }
+                       start = end;
+               }
+
+               tcp_xa_pool_commit(sk, &tcp_xa_pool);
+               if (!remaining_len)
+                       goto out;
+
+               /* if remaining_len is not satisfied yet, we need to go to the
+                * next frag in the frag_list to satisfy remaining_len.
+                */
+               skb = skb_shinfo(skb)->frag_list ?: skb->next;
+
+               /* Rebase offset onto the next skb's data. */
+               offset = offset - start;
+       } while (skb);
+
+       if (remaining_len) {
+               err = -EFAULT;
+               goto out;
+       }
+
+out:
+       tcp_xa_pool_commit(sk, &tcp_xa_pool);
+       if (!sent)
+               sent = err;
+
+       return sent;
+}
+
 /*
  *     This routine copies from a sock struct into the user buffer.
  *
@@ -2336,6 +2551,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct 
msghdr *msg, size_t len,
                              int *cmsg_flags)
 {
        struct tcp_sock *tp = tcp_sk(sk);
+       int last_copied_dmabuf = -1; /* uninitialized */
        int copied = 0;
        u32 peek_seq;
        u32 *seq;
@@ -2515,15 +2731,44 @@ static int tcp_recvmsg_locked(struct sock *sk, struct 
msghdr *msg, size_t len,
                }
 
                if (!(flags & MSG_TRUNC)) {
-                       err = skb_copy_datagram_msg(skb, offset, msg, used);
-                       if (err) {
-                               /* Exception. Bailout! */
-                               if (!copied)
-                                       copied = -EFAULT;
+                       if (last_copied_dmabuf != -1 &&
+                           last_copied_dmabuf != !skb_frags_readable(skb))
                                break;
+
+                       if (skb_frags_readable(skb)) {
+                               err = skb_copy_datagram_msg(skb, offset, msg,
+                                                           used);
+                               if (err) {
+                                       /* Exception. Bailout! */
+                                       if (!copied)
+                                               copied = -EFAULT;
+                                       break;
+                               }
+                       } else {
+                               if (!(flags & MSG_SOCK_DEVMEM)) {
+                                       /* dmabuf skbs can only be received
+                                        * with the MSG_SOCK_DEVMEM flag.
+                                        */
+                                       if (!copied)
+                                               copied = -EFAULT;
+
+                                       break;
+                               }
+
+                               err = tcp_recvmsg_dmabuf(sk, skb, offset, msg,
+                                                        used);
+                               if (err <= 0) {
+                                       if (!copied)
+                                               copied = -EFAULT;
+
+                                       break;
+                               }
+                               used = err;
                        }
                }
 
+               last_copied_dmabuf = !skb_frags_readable(skb);
+
                WRITE_ONCE(*seq, *seq + used);
                copied += used;
                len -= used;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index eb631e66ee03..5afe5e57c89b 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -79,6 +79,7 @@
 #include <linux/seq_file.h>
 #include <linux/inetdevice.h>
 #include <linux/btf_ids.h>
+#include <linux/skbuff_ref.h>
 
 #include <crypto/hash.h>
 #include <linux/scatterlist.h>
@@ -2512,10 +2513,25 @@ static void tcp_md5sig_info_free_rcu(struct rcu_head 
*head)
 }
 #endif
 
+/* Drop the page_pool reference held for every netmem still recorded in
+ * sk->sk_user_frags, i.e. frags user space has not given back. Compiles
+ * to an empty function without CONFIG_PAGE_POOL.
+ */
+static void tcp_release_user_frags(struct sock *sk)
+{
+#ifdef CONFIG_PAGE_POOL
+       unsigned long token;
+       void *entry;
+
+       xa_for_each(&sk->sk_user_frags, token, entry) {
+               netmem_ref ref = (__force netmem_ref)entry;
+
+               /* napi_pp_put_page() is expected to succeed for every
+                * stored entry; warn once if it does not.
+                */
+               WARN_ON_ONCE(!napi_pp_put_page(ref));
+       }
+#endif
+}
+
 void tcp_v4_destroy_sock(struct sock *sk)
 {
        struct tcp_sock *tp = tcp_sk(sk);
 
+       tcp_release_user_frags(sk);
+
+       xa_destroy(&sk->sk_user_frags);
+
        trace_tcp_destroy_sock(sk);
 
        tcp_clear_xmit_timers(sk);
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index ad562272db2e..bb1fe1ba867a 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -628,6 +628,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk,
 
        __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS);
 
+       xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1);
+
        return newsk;
 }
 EXPORT_SYMBOL(tcp_create_openreq_child);
-- 
2.46.0.469.g59c65b2a67-goog


Reply via email to