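We want to use the inet_sock_diag_destroy code to send sock_diag notifications for more types of TCP events than just socket destruction, so refactor the code to allow this. Rename the destroy-specific helpers sock_diag_destroy_group(), sock_diag_has_destroy_listeners() and sock_diag_broadcast_destroy() to the generic sock_diag_group(), sock_diag_has_listeners() and sock_diag_broadcast(). Add two new multicast groups, SKNLGRP_INET_TCP_3WH and SKNLGRP_INET_TCP_CONNECTED, chosen from the TCP socket state.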
Signed-off-by: Sowmini Varadhan <sowmini.varad...@oracle.com> --- include/linux/sock_diag.h | 18 +++++++++++++----- include/uapi/linux/sock_diag.h | 2 ++ net/core/sock.c | 4 ++-- net/core/sock_diag.c | 11 ++++++----- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/include/linux/sock_diag.h b/include/linux/sock_diag.h index 15fe980..df85767 100644 --- a/include/linux/sock_diag.h +++ b/include/linux/sock_diag.h @@ -34,7 +34,7 @@ int sock_diag_put_filterinfo(bool may_report_filterinfo, struct sock *sk, struct sk_buff *skb, int attrtype); static inline -enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk) +enum sknetlink_groups sock_diag_group(const struct sock *sk) { switch (sk->sk_family) { case AF_INET: @@ -43,7 +43,15 @@ enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk) switch (sk->sk_protocol) { case IPPROTO_TCP: - return SKNLGRP_INET_TCP_DESTROY; + switch (sk->sk_state) { + case TCP_ESTABLISHED: + return SKNLGRP_INET_TCP_CONNECTED; + case TCP_SYN_SENT: + case TCP_SYN_RECV: + return SKNLGRP_INET_TCP_3WH; + default: + return SKNLGRP_INET_TCP_DESTROY; + } case IPPROTO_UDP: return SKNLGRP_INET_UDP_DESTROY; default: @@ -67,15 +75,15 @@ enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk) } static inline -bool sock_diag_has_destroy_listeners(const struct sock *sk) +bool sock_diag_has_listeners(const struct sock *sk) { const struct net *n = sock_net(sk); - const enum sknetlink_groups group = sock_diag_destroy_group(sk); + const enum sknetlink_groups group = sock_diag_group(sk); return group != SKNLGRP_NONE && n->diag_nlsk && netlink_has_listeners(n->diag_nlsk, group); } -void sock_diag_broadcast_destroy(struct sock *sk); +void sock_diag_broadcast(struct sock *sk); int sock_diag_destroy(struct sock *sk, int err); #endif diff --git a/include/uapi/linux/sock_diag.h b/include/uapi/linux/sock_diag.h index e592500..4252674 100644 --- a/include/uapi/linux/sock_diag.h +++ b/include/uapi/linux/sock_diag.h @@ 
-32,6 +32,8 @@ enum sknetlink_groups { SKNLGRP_INET_UDP_DESTROY, SKNLGRP_INET6_TCP_DESTROY, SKNLGRP_INET6_UDP_DESTROY, + SKNLGRP_INET_TCP_3WH, + SKNLGRP_INET_TCP_CONNECTED, __SKNLGRP_MAX, }; #define SKNLGRP_MAX (__SKNLGRP_MAX - 1) diff --git a/net/core/sock.c b/net/core/sock.c index 7e8796a..6684840 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -1600,8 +1600,8 @@ static void __sk_free(struct sock *sk) if (likely(sk->sk_net_refcnt)) sock_inuse_add(sock_net(sk), -1); - if (unlikely(sk->sk_net_refcnt && sock_diag_has_destroy_listeners(sk))) - sock_diag_broadcast_destroy(sk); + if (unlikely(sk->sk_net_refcnt && sock_diag_has_listeners(sk))) + sock_diag_broadcast(sk); else sk_destruct(sk); } diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c index 3312a58..dbd9e65 100644 --- a/net/core/sock_diag.c +++ b/net/core/sock_diag.c @@ -116,14 +116,14 @@ static size_t sock_diag_nlmsg_size(void) + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */ } -static void sock_diag_broadcast_destroy_work(struct work_struct *work) +static void sock_diag_broadcast_work(struct work_struct *work) { struct broadcast_sk *bsk = container_of(work, struct broadcast_sk, work); struct sock *sk = bsk->sk; const struct sock_diag_handler *hndl; struct sk_buff *skb; - const enum sknetlink_groups group = sock_diag_destroy_group(sk); + const enum sknetlink_groups group = sock_diag_group(sk); int err = -1; WARN_ON(group == SKNLGRP_NONE); @@ -144,11 +144,12 @@ static void sock_diag_broadcast_destroy_work(struct work_struct *work) else kfree_skb(skb); out: - sk_destruct(sk); + if (group <= SKNLGRP_INET6_UDP_DESTROY) + sk_destruct(sk); kfree(bsk); } -void sock_diag_broadcast_destroy(struct sock *sk) +void sock_diag_broadcast(struct sock *sk) { /* Note, this function is often called from an interrupt context. 
*/ struct broadcast_sk *bsk = @@ -156,7 +157,7 @@ void sock_diag_broadcast_destroy(struct sock *sk) if (!bsk) return sk_destruct(sk); bsk->sk = sk; - INIT_WORK(&bsk->work, sock_diag_broadcast_destroy_work); + INIT_WORK(&bsk->work, sock_diag_broadcast_work); queue_work(broadcast_wq, &bsk->work); } -- 1.7.1