Move the call to skb_cow_data() from rxkad into rxrpc_recvmsg_data() and
make it as soon as the packet is first seen.  This means that
skb_cow_data() is called only once per packet, even for a jumbo packet
containing a bunch of subpackets.

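In rxkad, the return value of skb_cow_data() is then no longer available
to tell us how many scatter-gather entries the skb needs, so we have to
guess how large a scatter-gather table we need for decryption,
particularly in rxkad_verify_packet_2().  We do this either by
initialising a fixed-size sg table that should be large enough, or by
sizing the table from nr_frags on the skb.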

Fixes: 17926a79320a ("[AF_RXRPC]: Provide secure RxRPC sockets for use by userspace and kernel both")
Signed-off-by: David Howells <[email protected]>
---

 net/rxrpc/ar-internal.h |    1 +
 net/rxrpc/input.c       |    1 +
 net/rxrpc/recvmsg.c     |   11 ++++++++++-
 net/rxrpc/rxkad.c       |   32 +++++++++-----------------------
 4 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h
index d784d58e0a0d..a42d6b833675 100644
--- a/net/rxrpc/ar-internal.h
+++ b/net/rxrpc/ar-internal.h
@@ -190,6 +190,7 @@ struct rxrpc_skb_priv {
        u8              rx_flags;               /* Received packet flags */
 #define RXRPC_SKB_INCL_LAST    0x01            /* - Includes last packet */
 #define RXRPC_SKB_TX_BUFFER    0x02            /* - Is transmit buffer */
+#define RXRPC_SKB_NEEDS_COW    0x04            /* - Needs skb_cow_data() 
calling */
        union {
                int             remain;         /* amount of space remaining 
for next write */
 
diff --git a/net/rxrpc/input.c b/net/rxrpc/input.c
index 660b7eed39b7..4df39f391e9d 100644
--- a/net/rxrpc/input.c
+++ b/net/rxrpc/input.c
@@ -448,6 +448,7 @@ static void rxrpc_input_data(struct rxrpc_call *call, 
struct sk_buff *skb)
        }
 
        atomic_set(&sp->nr_ring_pins, 1);
+       sp->rx_flags |= RXRPC_SKB_NEEDS_COW;
 
        if (call->state == RXRPC_CALL_SERVER_RECV_REQUEST) {
                unsigned long timo = READ_ONCE(call->next_req_timo);
diff --git a/net/rxrpc/recvmsg.c b/net/rxrpc/recvmsg.c
index 82bb48d96526..ef50580b5295 100644
--- a/net/rxrpc/recvmsg.c
+++ b/net/rxrpc/recvmsg.c
@@ -305,7 +305,7 @@ static int rxrpc_recvmsg_data(struct socket *sock, struct 
rxrpc_call *call,
                              size_t len, int flags, size_t *_offset)
 {
        struct rxrpc_skb_priv *sp;
-       struct sk_buff *skb;
+       struct sk_buff *skb, *trailer;
        rxrpc_serial_t serial;
        rxrpc_seq_t hard_ack, top, seq;
        size_t remain;
@@ -343,6 +343,15 @@ static int rxrpc_recvmsg_data(struct socket *sock, struct 
rxrpc_call *call,
                rxrpc_see_skb(skb, rxrpc_skb_seen);
                sp = rxrpc_skb(skb);
 
+               if (sp->rx_flags & RXRPC_SKB_NEEDS_COW) {
+                       ret2 = skb_cow_data(skb, 0, &trailer);
+                       if (ret2 < 0) {
+                               ret = ret2;
+                               goto out;
+                       }
+                       sp->rx_flags &= ~RXRPC_SKB_NEEDS_COW;
+               }
+
                if (!(flags & MSG_PEEK)) {
                        serial = sp->hdr.serial;
                        serial += call->rxtx_annotations[ix] & 
RXRPC_RX_ANNO_SUBPACKET;
diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c
index ae8cd8926456..c60c520fde7c 100644
--- a/net/rxrpc/rxkad.c
+++ b/net/rxrpc/rxkad.c
@@ -187,10 +187,8 @@ static int rxkad_secure_packet_encrypt(const struct 
rxrpc_call *call,
        struct rxrpc_skb_priv *sp;
        struct rxrpc_crypt iv;
        struct scatterlist sg[16];
-       struct sk_buff *trailer;
        unsigned int len;
        u16 check;
-       int nsg;
        int err;
 
        sp = rxrpc_skb(skb);
@@ -214,15 +212,14 @@ static int rxkad_secure_packet_encrypt(const struct 
rxrpc_call *call,
        crypto_skcipher_encrypt(req);
 
        /* we want to encrypt the skbuff in-place */
-       nsg = skb_cow_data(skb, 0, &trailer);
-       err = -ENOMEM;
-       if (nsg < 0 || nsg > 16)
+       err = -EMSGSIZE;
+       if (skb_shinfo(skb)->nr_frags > 16)
                goto out;
 
        len = data_size + call->conn->size_align - 1;
        len &= ~(call->conn->size_align - 1);
 
-       sg_init_table(sg, nsg);
+       sg_init_table(sg, ARRAY_SIZE(sg));
        err = skb_to_sgvec(skb, sg, 0, len);
        if (unlikely(err < 0))
                goto out;
@@ -319,11 +316,10 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, 
struct sk_buff *skb,
        struct rxkad_level1_hdr sechdr;
        struct rxrpc_crypt iv;
        struct scatterlist sg[16];
-       struct sk_buff *trailer;
        bool aborted;
        u32 data_size, buf;
        u16 check;
-       int nsg, ret;
+       int ret;
 
        _enter("");
 
@@ -336,11 +332,7 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, 
struct sk_buff *skb,
        /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
         * directly into the target buffer.
         */
-       nsg = skb_cow_data(skb, 0, &trailer);
-       if (nsg < 0 || nsg > 16)
-               goto nomem;
-
-       sg_init_table(sg, nsg);
+       sg_init_table(sg, ARRAY_SIZE(sg));
        ret = skb_to_sgvec(skb, sg, offset, 8);
        if (unlikely(ret < 0))
                return ret;
@@ -388,10 +380,6 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, 
struct sk_buff *skb,
        if (aborted)
                rxrpc_send_abort_packet(call);
        return -EPROTO;
-
-nomem:
-       _leave(" = -ENOMEM");
-       return -ENOMEM;
 }
 
 /*
@@ -406,7 +394,6 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, 
struct sk_buff *skb,
        struct rxkad_level2_hdr sechdr;
        struct rxrpc_crypt iv;
        struct scatterlist _sg[4], *sg;
-       struct sk_buff *trailer;
        bool aborted;
        u32 data_size, buf;
        u16 check;
@@ -423,12 +410,11 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, 
struct sk_buff *skb,
        /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
         * directly into the target buffer.
         */
-       nsg = skb_cow_data(skb, 0, &trailer);
-       if (nsg < 0)
-               goto nomem;
-
        sg = _sg;
-       if (unlikely(nsg > 4)) {
+       nsg = skb_shinfo(skb)->nr_frags;
+       if (nsg <= 4) {
+               nsg = 4;
+       } else {
                sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
                if (!sg)
                        goto nomem;

Reply via email to