On Fri, Jan 12, 2018 at 10:11:11AM -0800, John Fastabend wrote:
> This implements a BPF ULP layer to allow policy enforcement and
> monitoring at the socket layer. In order to support this a new
> program type BPF_PROG_TYPE_SK_MSG is used to run the policy at
> the sendmsg/sendpage hook. To attach the policy to sockets a
> sockmap is used with a new program attach type BPF_SK_MSG_VERDICT.
> 
> Similar to previous sockmap usage, when a sock is added to a
> sockmap via a map update, if the map has a BPF_SK_MSG_VERDICT
> program attached then the BPF ULP layer is created on the
> socket and the attached BPF_PROG_TYPE_SK_MSG program is run for
> every msg in the sendmsg case and every page/offset in the sendpage case.
> 
> BPF_PROG_TYPE_SK_MSG Semantics/API:
> 
> BPF_PROG_TYPE_SK_MSG supports only two return codes SK_PASS and
> SK_DROP. Returning SK_DROP frees the copied data in the sendmsg case
> and leaves the data untouched in the sendpage case. Both cases
> return -EACCES to the user. Returning SK_PASS allows the msg to be
> sent.
> 
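As an aside for readers, a minimal SK_MSG program is just a verdict on the
whole msg. Untested sketch, not from the patch; SEC() and bpf_helpers.h are
assumed to come from the selftests infrastructure:

  /* Sketch: SK_MSG verdict program that passes everything.
   * Returning SK_DROP instead would make sendmsg()/sendpage() on any
   * attached socket fail with -EACCES.
   */
  #include <linux/bpf.h>
  #include "bpf_helpers.h"   /* SEC() macro, assumed selftests header */

  SEC("sk_msg")              /* section name is loader-dependent */
  int msg_pass_all(struct sk_msg_md *msg)
  {
          return SK_PASS;
  }

  char _license[] SEC("license") = "GPL";
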
> In the sendmsg case data is copied into kernel space buffers before
> running the BPF program. In the sendpage case data is never copied.
> The implication is that users may change data after BPF programs run in
> the sendpage case. (A flag will be added shortly to force a copy when
> the copy must always be performed.)
> 
> The verdict from the BPF_PROG_TYPE_SK_MSG applies to the entire msg
> in the sendmsg() case and the entire page/offset in the sendpage case.
> This avoids ambiguity on how to handle mixed return codes in the
> sendmsg case. The readable/writeable data provided to the program
> in the sendmsg case may not be the entire message; in fact, for
> large sends this is likely the case. The data range that can be
> read is part of the sk_msg_md structure. This is because, similar
> to the tc cls_bpf case, the data is stored in a scatter-gather list.
> Future work will address this shortcoming to allow users to pull
> in more data if needed (similar to TC BPF).
> 
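To make the limited window concrete, an untested sketch of the data access
from the program side; the verifier requires the usual bounds check against
data_end, and for large sends [data, data_end) may only cover the first
scatterlist element:

  /* Sketch: inspect the first byte of the readable window and drop
   * msgs that start with 0x00.  Assumes the same headers as the
   * earlier sketch (linux/bpf.h plus a selftests-style bpf_helpers.h).
   */
  SEC("sk_msg")
  int msg_inspect_first_byte(struct sk_msg_md *msg)
  {
          unsigned char *data = (unsigned char *)(long)msg->data;
          unsigned char *data_end = (unsigned char *)(long)msg->data_end;

          if (data + 1 > data_end)   /* bounds check required by verifier */
                  return SK_PASS;    /* nothing readable, let it through */

          return data[0] == 0x00 ? SK_DROP : SK_PASS;
  }
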
> The helper msg_redirect_map() can be used to select the socket to
> send the data on. This is used similarly to existing redirect use
> cases and allows the policy to redirect msgs.
> 
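And an untested sketch of the redirect path from the program side; the map
definition and helper declaration follow the selftests style, and the map
name/key are invented for illustration:

  /* Sketch: redirect every msg to the socket stored at key 0 of a
   * sockmap named "redir_map".  BPF_FUNC_msg_redirect_map is the
   * helper id added by this patch.  Same headers assumed as above.
   */
  struct bpf_map_def SEC("maps") redir_map = {
          .type        = BPF_MAP_TYPE_SOCKMAP,
          .key_size    = sizeof(int),
          .value_size  = sizeof(int),
          .max_entries = 2,
  };

  static int (*bpf_msg_redirect_map)(void *msg, void *map,
                                     __u32 key, __u64 flags) =
          (void *) BPF_FUNC_msg_redirect_map;

  SEC("sk_msg")
  int msg_redirect_all(struct sk_msg_md *msg)
  {
          return bpf_msg_redirect_map(msg, &redir_map, 0, 0);
  }
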
> Pseudo code simple example:
> 
> The basic logic to attach a program to a socket is as follows,
> 
>   // load the programs
>   bpf_prog_load(SOCKMAP_TCP_MSG_PROG, BPF_PROG_TYPE_SK_MSG,
>               &obj, &msg_prog);
> 
>   // lookup the sockmap
>   bpf_map_msg = bpf_object__find_map_by_name(obj, "my_sock_map");
> 
>   // get fd for sockmap
>   map_fd_msg = bpf_map__fd(bpf_map_msg);
> 
>   // attach program to sockmap
>   bpf_prog_attach(msg_prog, map_fd_msg, BPF_SK_MSG_VERDICT, 0);
> 
> Adding sockets to the map is done in the normal way,
> 
>   // Add a socket 'fd' to sockmap at location 'i'
>   bpf_map_update_elem(map_fd_msg, &i, &fd, BPF_ANY);
> 
> After the above any socket attached to "my_sock_map", in this case
> 'fd', will run the BPF msg verdict program (msg_prog) on every
> sendmsg and sendpage system call.
> 
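A small user-space sketch of that flow (error handling trimmed, accept loop
invented for illustration); bpf_map_update_elem() is the wrapper from
tools/lib/bpf:

  #include <unistd.h>
  #include <sys/socket.h>
  #include <linux/bpf.h>     /* BPF_ANY */
  #include "bpf/bpf.h"       /* bpf_map_update_elem(), assumed include path */

  /* Sketch: every connection accepted on 'listen_fd' is placed into the
   * sockmap, so its sendmsg()/sendpage() calls run the verdict program.
   */
  static int serve(int listen_fd, int map_fd_msg)
  {
          for (;;) {
                  int i = 0;
                  int fd = accept(listen_fd, NULL, NULL);

                  if (fd < 0)
                          return fd;

                  if (bpf_map_update_elem(map_fd_msg, &i, &fd, BPF_ANY) < 0)
                          close(fd);
          }
  }
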
> For a complete example see BPF selftests bpf/sockmap_tcp_msg_*.c and
> test_maps.c
> 
> Implementation notes:
> 
> It seemed simplest, to me at least, to use a refcnt to ensure the
> psock is not lost across the sendmsg copy into the sg, the bpf program
> running on the data in sg_data, and the final pass to the TCP stack.
> Some performance testing may show a better method to do this and avoid
> the refcnt cost, but for now use the simpler method.
> 
> Another item that will come after basic support is in place is
> supporting the MSG_MORE flag. At the moment we call sendpages even if
> the MSG_MORE flag is set. An enhancement would be to collect the
> pages into a larger scatterlist and pass it down the stack. Notice that
> bpf_tcp_sendmsg() could support this with some additional state saved
> across sendmsg calls. I built the code to support this without having
> to do refactoring work. Other flags TBD include the ZEROCOPY flag.
> 
> Yet another detail that needs some thought is the size of the scatterlist.
> Currently, we use MAX_SKB_FRAGS simply because this was being used
> already in the TLS case. Future work to improve the kernel sk APIs to
> tune this depending on workload may be useful. This is a trade-off
> between memory usage and B/s performance.
Some minor comments/nits below:

> 
> Signed-off-by: John Fastabend <john.fastab...@gmail.com>
> ---
>  include/linux/bpf.h       |    1 
>  include/linux/bpf_types.h |    1 
>  include/linux/filter.h    |   10 +
>  include/net/tcp.h         |    2 
>  include/uapi/linux/bpf.h  |   28 +++
>  kernel/bpf/sockmap.c      |  485 ++++++++++++++++++++++++++++++++++++++++++++-
>  kernel/bpf/syscall.c      |   14 +
>  kernel/bpf/verifier.c     |    5 
>  net/core/filter.c         |  106 ++++++++++
>  9 files changed, 638 insertions(+), 14 deletions(-)
> 
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 9e03046..14cdb4d 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -21,6 +21,7 @@
>  struct perf_event;
>  struct bpf_prog;
>  struct bpf_map;
> +struct sock;
>  
>  /* map is generic key/value storage optionally accesible by eBPF programs */
>  struct bpf_map_ops {
> diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
> index 19b8349..5e2e8a4 100644
> --- a/include/linux/bpf_types.h
> +++ b/include/linux/bpf_types.h
> @@ -13,6 +13,7 @@
>  BPF_PROG_TYPE(BPF_PROG_TYPE_LWT_XMIT, lwt_xmit)
>  BPF_PROG_TYPE(BPF_PROG_TYPE_SOCK_OPS, sock_ops)
>  BPF_PROG_TYPE(BPF_PROG_TYPE_SK_SKB, sk_skb)
> +BPF_PROG_TYPE(BPF_PROG_TYPE_SK_MSG, sk_msg)
>  #endif
>  #ifdef CONFIG_BPF_EVENTS
>  BPF_PROG_TYPE(BPF_PROG_TYPE_KPROBE, kprobe)
> diff --git a/include/linux/filter.h b/include/linux/filter.h
> index 425056c..f1e9833 100644
> --- a/include/linux/filter.h
> +++ b/include/linux/filter.h
> @@ -507,6 +507,15 @@ struct xdp_buff {
>       struct xdp_rxq_info *rxq;
>  };
>  
> +struct sk_msg_buff {
> +     void *data;
> +     void *data_end;
> +     struct scatterlist sg_data[MAX_SKB_FRAGS];
> +     __u32 key;
> +     __u32 flags;
> +     struct bpf_map *map;
> +};
> +
>  /* Compute the linear packet data range [data, data_end) which
>   * will be accessed by various program types (cls_bpf, act_bpf,
>   * lwt, ...). Subsystems allowing direct data access must (!)
> @@ -769,6 +778,7 @@ int xdp_do_redirect(struct net_device *dev,
>  void bpf_warn_invalid_xdp_action(u32 act);
>  
>  struct sock *do_sk_redirect_map(struct sk_buff *skb);
> +struct sock *do_msg_redirect_map(struct sk_msg_buff *md);
>  
>  #ifdef CONFIG_BPF_JIT
>  extern int bpf_jit_enable;
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index a99ceb8..7f56c3c 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -1984,6 +1984,7 @@ static inline void tcp_listendrop(const struct sock *sk)
>  
>  enum {
>       TCP_ULP_TLS,
> +     TCP_ULP_BPF,
>  };
>  
>  struct tcp_ulp_ops {
> @@ -2001,6 +2002,7 @@ struct tcp_ulp_ops {
>  int tcp_register_ulp(struct tcp_ulp_ops *type);
>  void tcp_unregister_ulp(struct tcp_ulp_ops *type);
>  int tcp_set_ulp(struct sock *sk, const char *name);
> +int tcp_set_ulp_id(struct sock *sk, const int ulp);
>  void tcp_get_available_ulp(char *buf, size_t len);
>  void tcp_cleanup_ulp(struct sock *sk);
>  
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index 405317f..bf649ae 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -133,6 +133,7 @@ enum bpf_prog_type {
>       BPF_PROG_TYPE_SOCK_OPS,
>       BPF_PROG_TYPE_SK_SKB,
>       BPF_PROG_TYPE_CGROUP_DEVICE,
> +     BPF_PROG_TYPE_SK_MSG,
>  };
>  
>  enum bpf_attach_type {
> @@ -143,6 +144,7 @@ enum bpf_attach_type {
>       BPF_SK_SKB_STREAM_PARSER,
>       BPF_SK_SKB_STREAM_VERDICT,
>       BPF_CGROUP_DEVICE,
> +     BPF_SK_MSG_VERDICT,
>       __MAX_BPF_ATTACH_TYPE
>  };
>  
> @@ -687,6 +689,15 @@ enum bpf_attach_type {
>   * int bpf_override_return(pt_regs, rc)
>   *   @pt_regs: pointer to struct pt_regs
>   *   @rc: the return value to set
> + *
> + * int bpf_msg_redirect_map(map, key, flags)
> + *     Redirect msg to a sock in map using key as a lookup key for the
> + *     sock in map.
> + *     @map: pointer to sockmap
> + *     @key: key to lookup sock in map
> + *     @flags: reserved for future use
> + *     Return: SK_PASS
> + *
>   */
>  #define __BPF_FUNC_MAPPER(FN)                \
>       FN(unspec),                     \
> @@ -747,7 +758,8 @@ enum bpf_attach_type {
>       FN(perf_event_read_value),      \
>       FN(perf_prog_read_value),       \
>       FN(getsockopt),                 \
> -     FN(override_return),
> +     FN(override_return),            \
> +     FN(msg_redirect_map),
>  
>  /* integer value in 'imm' field of BPF_CALL instruction selects which helper
>   * function eBPF program intends to call
> @@ -909,6 +921,20 @@ enum sk_action {
>       SK_PASS,
>  };
>  
> +/* User return codes for SK_MSG prog type. */
> +enum sk_msg_action {
> +     SK_MSG_DROP = 0,
> +     SK_MSG_PASS,
> +};
> +
> +/* user accessible metadata for SK_MSG packet hook, new fields must
> + * be added to the end of this structure
> + */
> +struct sk_msg_md {
> +     __u32 data;
> +     __u32 data_end;
> +};
> +
>  #define BPF_TAG_SIZE 8
>  
>  struct bpf_prog_info {
> diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
> index 972608f..5793f3a 100644
> --- a/kernel/bpf/sockmap.c
> +++ b/kernel/bpf/sockmap.c
> @@ -38,6 +38,7 @@
>  #include <linux/skbuff.h>
>  #include <linux/workqueue.h>
>  #include <linux/list.h>
> +#include <linux/mm.h>
>  #include <net/strparser.h>
>  #include <net/tcp.h>
>  
> @@ -47,6 +48,7 @@
>  struct bpf_stab {
>       struct bpf_map map;
>       struct sock **sock_map;
> +     struct bpf_prog *bpf_tx_msg;
>       struct bpf_prog *bpf_parse;
>       struct bpf_prog *bpf_verdict;
>  };
> @@ -74,6 +76,7 @@ struct smap_psock {
>       struct sk_buff *save_skb;
>  
>       struct strparser strp;
> +     struct bpf_prog *bpf_tx_msg;
>       struct bpf_prog *bpf_parse;
>       struct bpf_prog *bpf_verdict;
>       struct list_head maps;
> @@ -90,6 +93,8 @@ struct smap_psock {
>       void (*save_state_change)(struct sock *sk);
>  };
>  
> +static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> +
>  static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
>  {
>       return rcu_dereference_sk_user_data(sk);
> @@ -99,8 +104,439 @@ enum __sk_action {
>       __SK_DROP = 0,
>       __SK_PASS,
>       __SK_REDIRECT,
> +     __SK_NONE,
>  };
>  
> +static int memcopy_from_iter(struct sock *sk, struct scatterlist *sg,
> +                          int sg_num, struct iov_iter *from, int bytes)
> +{
> +     int i, rc = 0;
> +
> +     for (i = 0; i < sg_num; ++i) {
> +             int copy = sg[i].length;
> +             char *to = sg_virt(&sg[i]);
> +
> +             if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
> +                     rc = copy_from_iter_nocache(to, copy, from);
> +             else
> +                     rc = copy_from_iter(to, copy, from);
> +
> +             if (rc != copy) {
> +                     rc = -EFAULT;
> +                     goto out;
> +             }
> +
> +             bytes -= copy;
> +             if (!bytes)
> +                     break;
> +     }
> +out:
> +     return rc;
> +}
> +
> +static int bpf_tcp_push(struct sock *sk, struct scatterlist *sg,
> +                     int *sg_end, int flags, bool charge)
> +{
> +     int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
> +     int offset, ret = 0;
> +     struct page *p;
> +     size_t size;
> +
> +     size = sg->length;
> +     offset = sg->offset;
> +
> +     while (1) {
> +             if (sg_is_last(sg))
> +                     sendpage_flags = flags;
> +
> +             tcp_rate_check_app_limited(sk);
> +             p = sg_page(sg);
> +retry:
> +             ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
> +             if (ret != size) {
> +                     if (ret > 0) {
> +                             offset += ret;
> +                             size -= ret;
> +                             goto retry;
> +                     }
> +
> +                     if (charge)
> +                             sk_mem_uncharge(sk,
> +                                             sg->length - size - sg->offset);
> +
> +                     sg->offset = offset;
> +                     sg->length = size;
> +                     return ret;
> +             }
> +
> +             put_page(p);
> +             if (charge)
> +                     sk_mem_uncharge(sk, sg->length);
> +             *sg_end += 1;
> +             sg = sg_next(sg);
> +             if (!sg)
> +                     break;
> +
> +             offset = sg->offset;
> +             size = sg->length;
> +     }
> +
> +     return 0;
> +}
> +
> +static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
> +{
> +     md->data = sg_virt(md->sg_data);
> +     md->data_end = md->data + md->sg_data->length;
> +}
> +
> +static void return_mem_sg(struct sock *sk, struct scatterlist *sg, int end)
> +{
> +     int i;
> +
> +     for (i = 0; i < end; ++i)
> +             sk_mem_uncharge(sk, sg[i].length);
> +}
> +
> +static int free_sg(struct sock *sk, struct scatterlist *sg, int start, int len)
> +{
> +     int i, free = 0;
> +
> +     for (i = start; i < len; ++i) {
> +             free += sg[i].length;
> +             sk_mem_uncharge(sk, sg[i].length);
> +             put_page(sg_page(&sg[i]));
> +     }
> +
> +     return free;
> +}
> +
> +static unsigned int smap_do_tx_msg(struct sock *sk,
> +                                struct smap_psock *psock,
> +                                struct sk_msg_buff *md)
> +{
> +     struct bpf_prog *prog;
> +     unsigned int rc, _rc;
> +
> +     preempt_disable();
Why is preempt_disable() needed here?

> +     rcu_read_lock();
> +
> +     /* If the policy was removed mid-send then default to 'accept' */
> +     prog = READ_ONCE(psock->bpf_tx_msg);
> +     if (unlikely(!prog)) {
> +             _rc = SK_PASS;
> +             goto verdict;
> +     }
> +
> +     bpf_compute_data_pointers_sg(md);
> +     _rc = (*prog->bpf_func)(md, prog->insnsi);
> +
> +verdict:
> +     rcu_read_unlock();
> +     preempt_enable();
> +
> +     /* Moving return codes from UAPI namespace into internal namespace */
> +     rc = ((_rc == SK_PASS) ?
> +           (md->map ? __SK_REDIRECT : __SK_PASS) :
> +           __SK_DROP);
> +
> +     return rc;
> +}
> +
> +static int bpf_tcp_sendmsg_do_redirect(struct scatterlist *sg, int sg_num,
> +                                    struct sk_msg_buff *md, int flags)
> +{
> +     int i, sg_curr = 0, err, free;
> +     struct smap_psock *psock;
> +     struct sock *sk;
> +
> +     rcu_read_lock();
> +     sk = do_msg_redirect_map(md);
> +     if (unlikely(!sk))
> +             goto out_rcu;
> +
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto out_rcu;
> +
> +     if (!refcount_inc_not_zero(&psock->refcnt))
> +             goto out_rcu;
> +
> +     rcu_read_unlock();
> +     lock_sock(sk);
> +     err = bpf_tcp_push(sk, sg, &sg_curr, flags, false);
> +     if (unlikely(err))
> +             goto out;
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     return 0;
> +out_rcu:
> +     rcu_read_unlock();
> +out:
> +     for (i = sg_curr; i < sg_num; ++i) {
> +             free += sg[i].length;
'free' is used without being initialized.
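Something along these lines (untested) would avoid summing into an
uninitialized value on the error paths:

  /* untested: initialize 'free' so the out/out_rcu paths do not
   * accumulate into an uninitialized value
   */
  int i, sg_curr = 0, err, free = 0;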

> +             put_page(sg_page(&sg[i]));
> +     }
> +     return free;
> +}
> +
> +static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
> +{
> +     int err = 0, eval = __SK_NONE, sg_size = 0, sg_num = 0;
> +     int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
> +     struct sk_msg_buff md = {0};
> +     struct smap_psock *psock;
> +     size_t copy, copied = 0;
> +     struct scatterlist *sg;
> +     long timeo;
> +
> +     sg = md.sg_data;
> +     sg_init_table(sg, MAX_SKB_FRAGS);
> +
> +     /* Its possible a sock event or user removed the psock _but_ the ops
> +      * have not been reprogrammed yet so we get here. In this case fallback
> +      * to tcp_sendmsg. Note this only works because we _only_ ever allow
> +      * a single ULP there is no hierarchy here.
> +      */
> +     rcu_read_lock();
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock)) {
> +             rcu_read_unlock();
> +             return tcp_sendmsg(sk, msg, size);
> +     }
> +
> +     /* Increment the psock refcnt to ensure its not released while sending a
> +      * message. Required because sk lookup and bpf programs are used in
> +      * separate rcu critical sections. Its OK if we lose the map entry
> +      * but we can't lose the sock reference, possible when the refcnt hits
> +      * zero and garbage collection calls sock_put().
> +      */
> +     if (!refcount_inc_not_zero(&psock->refcnt)) {
> +             rcu_read_unlock();
> +             return tcp_sendmsg(sk, msg, size);
> +     }
> +
> +     rcu_read_unlock();
> +
> +     lock_sock(sk);
> +     timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
> +
> +     while (msg_data_left(msg)) {
> +             int sg_curr;
> +
> +             if (sk->sk_err) {
> +                     err = sk->sk_err;
> +                     goto out_err;
> +             }
> +
> +             copy = msg_data_left(msg);
> +             if (!sk_stream_memory_free(sk))
> +                     goto wait_for_sndbuf;
> +
> +             /* sg_size indicates bytes already allocated and sg_num
> +              * is last sg element used. This is used when alloc_sg
s/alloc_sg/sk_alloc_sg/

> +              * partially allocates a scatterlist and then is sent
> +              * to wait for memory. In normal case (no memory pressure)
> +              * both sg_nun and sg_size are zero.
s/sg_nun/sg_num/

> +              */
> +             copy = copy - sg_size;
> +             err = sk_alloc_sg(sk, copy, sg, &sg_num, &sg_size, 0);
> +             if (err) {
> +                     if (err != -ENOSPC)
> +                             goto wait_for_memory;
> +                     copy = sg_size;
> +             }
> +
> +             err = memcopy_from_iter(sk, sg, sg_num, &msg->msg_iter, copy);
> +             if (err < 0) {
> +                     free_sg(sk, sg, 0, sg_num);
> +                     goto out_err;
> +             }
> +
> +             copied += copy;
> +
> +             /* If msg is larger than MAX_SKB_FRAGS we can send multiple
> +              * scatterlists per msg. However BPF decisions apply to the
> +              * entire msg.
> +              */
> +             if (eval == __SK_NONE)
> +                     eval = smap_do_tx_msg(sk, psock, &md);
> +
> +             switch (eval) {
> +             case __SK_PASS:
> +                     sg_mark_end(sg + sg_num - 1);
> +                     err = bpf_tcp_push(sk, sg, &sg_curr, flags, true);
> +                     if (unlikely(err)) {
> +                             copied -= free_sg(sk, sg, sg_curr, sg_num);
> +                             goto out_err;
> +                     }
> +                     break;
> +             case __SK_REDIRECT:
> +                     sg_mark_end(sg + sg_num - 1);
> +                     goto do_redir;
> +             case __SK_DROP:
> +             default:
> +                     copied -= free_sg(sk, sg, 0, sg_num);
> +                     goto out_err;
> +             }
> +
> +             sg_num = 0;
> +             sg_size = 0;
> +             continue;
> +wait_for_sndbuf:
> +             set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
> +wait_for_memory:
> +             err = sk_stream_wait_memory(sk, &timeo);
> +             if (err)
> +                     goto out_err;
> +     }
> +out_err:
> +     if (err < 0)
> +             err = sk_stream_error(sk, msg->msg_flags, err);
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     return copied ? copied : err;
> +
> +do_redir:
> +     /* To avoid deadlock with multiple socks all doing redirects to
> +      * each other we must first drop the current sock lock and release
> +      * the psock. Then get the redirect socket (assuming it still
> +      * exists), take its lock, and finally do the send here. If the
> +      * redirect fails there is nothing to do, we don't want to blame
> +      * the sender for remote socket failures. Instead we simply
> +      * continue making forward progress.
> +      */
> +     return_mem_sg(sk, sg, sg_num);
> +     release_sock(sk);
> +     smap_release_sock(psock, sk);
> +     copied -= bpf_tcp_sendmsg_do_redirect(sg, sg_num, &md, flags);
> +     return copied;
For the __SK_REDIRECT case, before returning, should 'msg_data_left(msg)' be
checked first?  Or will msg_data_left(msg) always be 0 here?

> +}
> +
> +static int bpf_tcp_sendpage_do_redirect(struct page *page, int offset,
> +                                     size_t size, int flags,
> +                                     struct sk_msg_buff *md)
> +{
> +     struct smap_psock *psock;
> +     struct sock *sk;
> +     int rc;
> +
> +     rcu_read_lock();
> +     sk = do_msg_redirect_map(md);
> +     if (unlikely(!sk))
> +             goto out_rcu;
> +
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto out_rcu;
> +
> +     if (!refcount_inc_not_zero(&psock->refcnt))
> +             goto out_rcu;
> +
> +     rcu_read_unlock();
> +
> +     lock_sock(sk);
> +     rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> +     release_sock(sk);
> +
> +     smap_release_sock(psock, sk);
> +     return rc;
> +out_rcu:
> +     rcu_read_unlock();
> +     return -EINVAL;
> +}
> +
> +static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
> +                         int offset, size_t size, int flags)
> +{
> +     struct smap_psock *psock;
> +     int rc, _rc = __SK_PASS;
> +     struct bpf_prog *prog;
> +     struct sk_msg_buff md;
> +
> +     preempt_disable();
> +     rcu_read_lock();
> +     psock = smap_psock_sk(sk);
> +     if (unlikely(!psock))
> +             goto verdict;
> +
> +     /* If the policy was removed mid-send then default to 'accept' */
> +     prog = READ_ONCE(psock->bpf_tx_msg);
> +     if (unlikely(!prog))
> +             goto verdict;
> +
> +     /* Calculate pkt data pointers and run BPF program */
> +     md.data = page_address(page) + offset;
> +     md.data_end = md.data + size;
> +     _rc = (*prog->bpf_func)(&md, prog->insnsi);
> +
> +verdict:
> +     rcu_read_unlock();
> +     preempt_enable();
> +
> +     /* Moving return codes from UAPI namespace into internal namespace */
> +     rc = ((_rc == SK_PASS) ? __SK_PASS : __SK_DROP);
> +
> +     switch (rc) {
> +     case __SK_PASS:
> +             lock_sock(sk);
> +             rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> +             release_sock(sk);
> +             break;
> +     case __SK_REDIRECT:
> +             smap_release_sock(psock, sk);
Is smap_release_sock() only needed in the __SK_REDIRECT case?

> +             rc = bpf_tcp_sendpage_do_redirect(page, offset, size, flags,
> +                                               &md);
> +             break;
> +     case __SK_DROP:
> +     default:
> +             rc = -EACCES;
> +     }
> +
> +     return rc;
> +}
> +
> +static int bpf_tcp_msg_add(struct smap_psock *psock,
> +                        struct sock *sk,
> +                        struct bpf_prog *tx_msg)
> +{
> +     struct bpf_prog *orig_tx_msg;
> +
> +     orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
> +     if (orig_tx_msg)
> +             bpf_prog_put(orig_tx_msg);
> +
> +     return tcp_set_ulp_id(sk, TCP_ULP_BPF);
> +}
> +
> +struct proto tcp_bpf_proto;
> +static int bpf_tcp_init(struct sock *sk)
> +{
> +     sk->sk_prot = &tcp_bpf_proto;
> +     return 0;
> +}
> +
> +static void bpf_tcp_release(struct sock *sk)
> +{
> +     sk->sk_prot = &tcp_prot;
> +}
> +
> +static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
> +     .name                   = "bpf_tcp",
> +     .uid                    = TCP_ULP_BPF,
> +     .owner                  = NULL,
> +     .init                   = bpf_tcp_init,
> +     .release                = bpf_tcp_release,
> +};
> +
> +static int bpf_tcp_ulp_register(void)
> +{
> +     tcp_bpf_proto = tcp_prot;
> +     tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
> +     tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
> +     return tcp_register_ulp(&bpf_tcp_ulp_ops);
> +}
> +
>  static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
>  {
>       struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
> @@ -165,8 +601,6 @@ static void smap_report_sk_error(struct smap_psock *psock, int err)
>       sk->sk_error_report(sk);
>  }
>  
> -static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> -
>  /* Called with lock_sock(sk) held */
>  static void smap_state_change(struct sock *sk)
>  {
> @@ -317,6 +751,7 @@ static void smap_write_space(struct sock *sk)
>  
>  static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
>  {
> +     tcp_cleanup_ulp(sk);
>       if (!psock->strp_enabled)
>               return;
>       sk->sk_data_ready = psock->save_data_ready;
> @@ -384,7 +819,6 @@ static int smap_parse_func_strparser(struct strparser *strp,
>       return rc;
>  }
>  
> -
>  static int smap_read_sock_done(struct strparser *strp, int err)
>  {
>       return err;
> @@ -456,6 +890,8 @@ static void smap_gc_work(struct work_struct *w)
>               bpf_prog_put(psock->bpf_parse);
>       if (psock->bpf_verdict)
>               bpf_prog_put(psock->bpf_verdict);
> +     if (psock->bpf_tx_msg)
> +             bpf_prog_put(psock->bpf_tx_msg);
>  
>       list_for_each_entry_safe(e, tmp, &psock->maps, list) {
>               list_del(&e->list);
> @@ -491,8 +927,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
>  
>  static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
>  {
> -     struct bpf_stab *stab;
> -     int err = -EINVAL;
> +     struct bpf_stab *stab; int err = -EINVAL;
>       u64 cost;
>  
>       if (!capable(CAP_NET_ADMIN))
> @@ -506,6 +941,10 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
>       if (attr->value_size > KMALLOC_MAX_SIZE)
>               return ERR_PTR(-E2BIG);
>  
> +     err = bpf_tcp_ulp_register();
> +     if (err && err != -EEXIST)
> +             return ERR_PTR(err);
> +
>       stab = kzalloc(sizeof(*stab), GFP_USER);
>       if (!stab)
>               return ERR_PTR(-ENOMEM);
> @@ -590,6 +1029,8 @@ static void sock_map_free(struct bpf_map *map)
>               bpf_prog_put(stab->bpf_verdict);
>       if (stab->bpf_parse)
>               bpf_prog_put(stab->bpf_parse);
> +     if (stab->bpf_tx_msg)
> +             bpf_prog_put(stab->bpf_tx_msg);
>  
>       sock_map_remove_complete(stab);
>  }
> @@ -684,7 +1125,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>  {
>       struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
>       struct smap_psock_map_entry *e = NULL;
> -     struct bpf_prog *verdict, *parse;
> +     struct bpf_prog *verdict, *parse, *tx_msg;
>       struct sock *osock, *sock;
>       struct smap_psock *psock;
>       u32 i = *(u32 *)key;
> @@ -710,6 +1151,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>        */
>       verdict = READ_ONCE(stab->bpf_verdict);
>       parse = READ_ONCE(stab->bpf_parse);
> +     tx_msg = READ_ONCE(stab->bpf_tx_msg);
>  
>       if (parse && verdict) {
>               /* bpf prog refcnt may be zero if a concurrent attach operation
> @@ -728,6 +1170,17 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>               }
>       }
>  
> +     if (tx_msg) {
> +             tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
> +             if (IS_ERR(tx_msg)) {
> +                     if (verdict)
> +                             bpf_prog_put(verdict);
> +                     if (parse)
> +                             bpf_prog_put(parse);
> +                     return PTR_ERR(tx_msg);
> +             }
> +     }
> +
>       write_lock_bh(&sock->sk_callback_lock);
>       psock = smap_psock_sk(sock);
>  
> @@ -742,7 +1195,14 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>                       err = -EBUSY;
>                       goto out_progs;
>               }
> -             refcount_inc(&psock->refcnt);
> +             if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
> +                     err = -EBUSY;
> +                     goto out_progs;
> +             }
> +             if (!refcount_inc_not_zero(&psock->refcnt)) {
> +                     err = -EAGAIN;
> +                     goto out_progs;
> +             }
>       } else {
>               psock = smap_init_psock(sock, stab);
>               if (IS_ERR(psock)) {
> @@ -763,6 +1223,12 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>       /* 3. At this point we have a reference to a valid psock that is
>        * running. Attach any BPF programs needed.
>        */
> +     if (tx_msg) {
> +             err = bpf_tcp_msg_add(psock, sock, tx_msg);
> +             if (err)
> +                     goto out_free;
> +     }
> +
>       if (parse && verdict && !psock->strp_enabled) {
>               err = smap_init_sock(psock, sock);
>               if (err)
> @@ -798,6 +1264,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
>               bpf_prog_put(verdict);
>       if (parse)
>               bpf_prog_put(parse);
> +     if (tx_msg)
> +             bpf_prog_put(tx_msg);
>       write_unlock_bh(&sock->sk_callback_lock);
>       kfree(e);
>       return err;
> @@ -812,6 +1280,9 @@ int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
>               return -EINVAL;
>  
>       switch (type) {
> +     case BPF_SK_MSG_VERDICT:
> +             orig = xchg(&stab->bpf_tx_msg, prog);
> +             break;
>       case BPF_SK_SKB_STREAM_PARSER:
>               orig = xchg(&stab->bpf_parse, prog);
>               break;
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index ebf0fb2..d32f093 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -1267,7 +1267,8 @@ static int bpf_obj_get(const union bpf_attr *attr)
>  
>  #define BPF_PROG_ATTACH_LAST_FIELD attach_flags
>  
> -static int sockmap_get_from_fd(const union bpf_attr *attr, bool attach)
> +static int sockmap_get_from_fd(const union bpf_attr *attr,
> +                            int type, bool attach)
>  {
>       struct bpf_prog *prog = NULL;
>       int ufd = attr->target_fd;
> @@ -1281,8 +1282,7 @@ static int sockmap_get_from_fd(const union bpf_attr *attr, bool attach)
>               return PTR_ERR(map);
>  
>       if (attach) {
> -             prog = bpf_prog_get_type(attr->attach_bpf_fd,
> -                                      BPF_PROG_TYPE_SK_SKB);
> +             prog = bpf_prog_get_type(attr->attach_bpf_fd, type);
>               if (IS_ERR(prog)) {
>                       fdput(f);
>                       return PTR_ERR(prog);
> @@ -1334,9 +1334,11 @@ static int bpf_prog_attach(const union bpf_attr *attr)
>       case BPF_CGROUP_DEVICE:
>               ptype = BPF_PROG_TYPE_CGROUP_DEVICE;
>               break;
> +     case BPF_SK_MSG_VERDICT:
> +             return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_MSG, true);
>       case BPF_SK_SKB_STREAM_PARSER:
>       case BPF_SK_SKB_STREAM_VERDICT:
> -             return sockmap_get_from_fd(attr, true);
> +             return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_SKB, true);
>       default:
>               return -EINVAL;
>       }
> @@ -1389,9 +1391,11 @@ static int bpf_prog_detach(const union bpf_attr *attr)
>       case BPF_CGROUP_DEVICE:
>               ptype = BPF_PROG_TYPE_CGROUP_DEVICE;
>               break;
> +     case BPF_SK_MSG_VERDICT:
> +             return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_MSG, false);
>       case BPF_SK_SKB_STREAM_PARSER:
>       case BPF_SK_SKB_STREAM_VERDICT:
> -             return sockmap_get_from_fd(attr, false);
> +             return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_SKB, false);
>       default:
>               return -EINVAL;
>       }
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index a2b2112..15c5c2a 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -1240,6 +1240,7 @@ static bool may_access_direct_pkt_data(struct bpf_verifier_env *env,
>       case BPF_PROG_TYPE_XDP:
>       case BPF_PROG_TYPE_LWT_XMIT:
>       case BPF_PROG_TYPE_SK_SKB:
> +     case BPF_PROG_TYPE_SK_MSG:
>               if (meta)
>                       return meta->pkt_access;
>  
> @@ -2041,7 +2042,8 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
>       case BPF_MAP_TYPE_SOCKMAP:
>               if (func_id != BPF_FUNC_sk_redirect_map &&
>                   func_id != BPF_FUNC_sock_map_update &&
> -                 func_id != BPF_FUNC_map_delete_elem)
> +                 func_id != BPF_FUNC_map_delete_elem &&
> +                 func_id != BPF_FUNC_msg_redirect_map)
>                       goto error;
>               break;
>       default:
> @@ -2079,6 +2081,7 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
>                       goto error;
>               break;
>       case BPF_FUNC_sk_redirect_map:
> +     case BPF_FUNC_msg_redirect_map:
>               if (map->map_type != BPF_MAP_TYPE_SOCKMAP)
>                       goto error;
>               break;
> diff --git a/net/core/filter.c b/net/core/filter.c
> index acdb94c..ca87b8d 100644
> --- a/net/core/filter.c
> +++ b/net/core/filter.c
> @@ -1881,6 +1881,44 @@ struct sock *do_sk_redirect_map(struct sk_buff *skb)
>       .arg4_type      = ARG_ANYTHING,
>  };
>  
> +BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
> +        struct bpf_map *, map, u32, key, u64, flags)
> +{
> +     /* If user passes invalid input drop the packet. */
> +     if (unlikely(flags))
> +             return SK_DROP;
> +
> +     msg->key = key;
> +     msg->flags = flags;
> +     msg->map = map;
> +
> +     return SK_PASS;
> +}
> +
> +struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
> +{
> +     struct sock *sk = NULL;
> +
> +     if (msg->map) {
> +             sk = __sock_map_lookup_elem(msg->map, msg->key);
> +
> +             msg->key = 0;
> +             msg->map = NULL;
> +     }
> +
> +     return sk;
> +}
> +
> +static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
> +     .func           = bpf_msg_redirect_map,
> +     .gpl_only       = false,
> +     .ret_type       = RET_INTEGER,
> +     .arg1_type      = ARG_PTR_TO_CTX,
> +     .arg2_type      = ARG_CONST_MAP_PTR,
> +     .arg3_type      = ARG_ANYTHING,
> +     .arg4_type      = ARG_ANYTHING,
> +};
> +
>  BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
>  {
>       return task_get_classid(skb);
> @@ -3513,6 +3551,16 @@ static unsigned long bpf_xdp_copy(void *dst_buff, const void *src_buff,
>       }
>  }
>  
> +static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id)
> +{
> +     switch (func_id) {
> +     case BPF_FUNC_msg_redirect_map:
> +             return &bpf_msg_redirect_map_proto;
> +     default:
> +             return bpf_base_func_proto(func_id);
> +     }
> +}
> +
>  static const struct bpf_func_proto *sk_skb_func_proto(enum bpf_func_id func_id)
>  {
>       switch (func_id) {
> @@ -3892,6 +3940,32 @@ static bool sk_skb_is_valid_access(int off, int size,
>       return bpf_skb_is_valid_access(off, size, type, info);
>  }
>  
> +static bool sk_msg_is_valid_access(int off, int size,
> +                                enum bpf_access_type type,
> +                                struct bpf_insn_access_aux *info)
> +{
> +     if (type == BPF_WRITE)
> +             return false;
> +
> +     switch (off) {
> +     case offsetof(struct sk_msg_md, data):
> +             info->reg_type = PTR_TO_PACKET;
> +             break;
> +     case offsetof(struct sk_msg_md, data_end):
> +             info->reg_type = PTR_TO_PACKET_END;
> +             break;
> +     }
> +
> +     if (off < 0 || off >= sizeof(struct sk_msg_md))
> +             return false;
> +     if (off % size != 0)
> +             return false;
> +     if (size != sizeof(__u32))
> +             return false;
> +
> +     return true;
> +}
> +
>  static u32 bpf_convert_ctx_access(enum bpf_access_type type,
>                                 const struct bpf_insn *si,
>                                 struct bpf_insn *insn_buf,
> @@ -4522,6 +4596,29 @@ static u32 sk_skb_convert_ctx_access(enum bpf_access_type type,
>       return insn - insn_buf;
>  }
>  
> +static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
> +                                  const struct bpf_insn *si,
> +                                  struct bpf_insn *insn_buf,
> +                                  struct bpf_prog *prog, u32 *target_size)
> +{
> +     struct bpf_insn *insn = insn_buf;
> +
> +     switch (si->off) {
> +     case offsetof(struct sk_msg_md, data):
> +             *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
> +                                   si->dst_reg, si->src_reg,
> +                                   offsetof(struct sk_msg_buff, data));
> +             break;
> +     case offsetof(struct sk_msg_md, data_end):
> +             *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
> +                                   si->dst_reg, si->src_reg,
> +                                   offsetof(struct sk_msg_buff, data_end));
> +             break;
> +     }
> +
> +     return insn - insn_buf;
> +}
> +
>  const struct bpf_verifier_ops sk_filter_verifier_ops = {
>       .get_func_proto         = sk_filter_func_proto,
>       .is_valid_access        = sk_filter_is_valid_access,
> @@ -4611,6 +4708,15 @@ static u32 sk_skb_convert_ctx_access(enum bpf_access_type type,
>  const struct bpf_prog_ops sk_skb_prog_ops = {
>  };
>  
> +const struct bpf_verifier_ops sk_msg_verifier_ops = {
> +     .get_func_proto         = sk_msg_func_proto,
> +     .is_valid_access        = sk_msg_is_valid_access,
> +     .convert_ctx_access     = sk_msg_convert_ctx_access,
> +};
> +
> +const struct bpf_prog_ops sk_msg_prog_ops = {
> +};
> +
>  int sk_detach_filter(struct sock *sk)
>  {
>       int ret = -ENOENT;
> 
