From: John Fastabend <john.fastab...@gmail.com>

[ Upstream commit 84f44df664e9f0e261157e16ee1acd77cc1bb78d ]

Similar to patch ("bpf: sock_ops ctx access may stomp registers"), if
src_reg == dst_reg when reading the sk field of a sock_ops struct, we
generate the following xlated code:

  53: (61) r9 = *(u32 *)(r9 +28)
  54: (15) if r9 == 0x0 goto pc+3
  56: (79) r9 = *(u64 *)(r9 +0)

This stomps on r9 to do the sk_fullsock check: r9 is overwritten with
the is_fullsock value, so the subsequent load of skops->sk uses that
value as the base address instead of the ctx pointer. To fix this, use
the same pattern as the previous fix and use the temp field to
save/restore a scratch register used for the sk_fullsock check.

After the fix the generated xlated code reads,

  52: (7b) *(u64 *)(r9 +32) = r8
  53: (61) r8 = *(u32 *)(r9 +28)
  54: (15) if r8 == 0x0 goto pc+3
  55: (79) r8 = *(u64 *)(r9 +32)
  56: (79) r9 = *(u64 *)(r9 +0)
  57: (05) goto pc+1
  58: (79) r8 = *(u64 *)(r9 +32)

Here r9 is in use, so r8 is chosen as the temporary register. At line
52 r8 is saved to the temp field, and at line 55 it is restored when
fullsock != 0. Finally, the fullsock == 0 case is handled by restoring
r8 at line 58.

This adds a new macro, SOCK_OPS_GET_SK. It is almost possible to merge
it with SOCK_OPS_GET_FIELD, but I found the extra branch logic more
confusing than just adding a new macro, despite some duplicated code.

Fixes: 1314ef561102e ("bpf: export bpf_sock for BPF_PROG_TYPE_SOCK_OPS prog type")
Signed-off-by: John Fastabend <john.fastab...@gmail.com>
Signed-off-by: Daniel Borkmann <dan...@iogearbox.net>
Acked-by: Song Liu <songliubrav...@fb.com>
Acked-by: Martin KaFai Lau <ka...@fb.com>
Link: https://lore.kernel.org/bpf/159718349653.4728.6559437186853473612.stgit@john-Precision-5820-Tower
Signed-off-by: Sasha Levin <sas...@kernel.org>
---
 net/core/filter.c | 49 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/net/core/filter.c b/net/core/filter.c
index bd1e46d61d8a1..5c490d473df1d 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -8010,6 +8010,43 @@ static u32 sock_ops_convert_ctx_access(enum 
bpf_access_type type,
                                      offsetof(OBJ, OBJ_FIELD));              \
        } while (0)
 
+#define SOCK_OPS_GET_SK()                                                    \
+       do {                                                                  \
+               int fullsock_reg = si->dst_reg, reg = BPF_REG_9, jmp = 1;     \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               /* dst == src: spill a scratch reg for the fullsock test */   \
+               if (si->dst_reg == si->src_reg) {                             \
+                       *insn++ = BPF_STX_MEM(BPF_DW, si->src_reg, reg,       \
+                               offsetof(struct bpf_sock_ops_kern, temp));    \
+                       fullsock_reg = reg;                                   \
+                       jmp += 2;                                             \
+               }                                                             \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern,     \
+                                               is_fullsock),                 \
+                                     fullsock_reg, si->src_reg,              \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              is_fullsock));                 \
+               *insn++ = BPF_JMP_IMM(BPF_JEQ, fullsock_reg, 0, jmp);         \
+               if (si->dst_reg == si->src_reg)                               \
+                       *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg,       \
+                               offsetof(struct bpf_sock_ops_kern, temp));    \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern, sk),\
+                                     si->dst_reg, si->src_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern, sk));\
+               /* !fullsock: restore reg, then make sk read as NULL */       \
+               if (si->dst_reg == si->src_reg) {                             \
+                       *insn++ = BPF_JMP_A(2);                               \
+                       *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg,       \
+                               offsetof(struct bpf_sock_ops_kern, temp));    \
+                       *insn++ = BPF_MOV64_IMM(si->dst_reg, 0);              \
+               }                                                             \
+       } while (0)
+
 #define SOCK_OPS_GET_TCP_SOCK_FIELD(FIELD) \
                SOCK_OPS_GET_FIELD(FIELD, FIELD, struct tcp_sock)
 
@@ -8294,17 +8331,7 @@ static u32 sock_ops_convert_ctx_access(enum 
bpf_access_type type,
                SOCK_OPS_GET_TCP_SOCK_FIELD(bytes_acked);
                break;
        case offsetof(struct bpf_sock_ops, sk):
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                               struct bpf_sock_ops_kern,
-                                               is_fullsock),
-                                     si->dst_reg, si->src_reg,
-                                     offsetof(struct bpf_sock_ops_kern,
-                                              is_fullsock));
-               *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1);
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
-                                               struct bpf_sock_ops_kern, sk),
-                                     si->dst_reg, si->src_reg,
-                                     offsetof(struct bpf_sock_ops_kern, sk));
+               SOCK_OPS_GET_SK();
                break;
        }
        return insn - insn_buf;
-- 
2.25.1



Reply via email to