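We want to use the sock_diag destroy-notification code to send
notifications for more types of TCP events than just socket close,
so refactor the code to allow this: drop "destroy" from the helper
names and add the SKNLGRP_INET_TCP_3WH and SKNLGRP_INET_TCP_CONNECTED
netlink groups.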

Signed-off-by: Sowmini Varadhan <sowmini.varad...@oracle.com>
---
 include/linux/sock_diag.h      |   18 +++++++++++++-----
 include/uapi/linux/sock_diag.h |    2 ++
 net/core/sock.c                |    4 ++--
 net/core/sock_diag.c           |   11 ++++++-----
 4 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/include/linux/sock_diag.h b/include/linux/sock_diag.h
index 15fe980..df85767 100644
--- a/include/linux/sock_diag.h
+++ b/include/linux/sock_diag.h
@@ -34,7 +34,7 @@ int sock_diag_put_filterinfo(bool may_report_filterinfo, struct sock *sk,
                             struct sk_buff *skb, int attrtype);
 
 static inline
-enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk)
+enum sknetlink_groups sock_diag_group(const struct sock *sk)
 {
        switch (sk->sk_family) {
        case AF_INET:
@@ -43,7 +43,15 @@ enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk)
 
                switch (sk->sk_protocol) {
                case IPPROTO_TCP:
-                       return SKNLGRP_INET_TCP_DESTROY;
+                       switch (sk->sk_state) {
+                       case TCP_ESTABLISHED:
+                               return SKNLGRP_INET_TCP_CONNECTED;
+                       case TCP_SYN_SENT:
+                       case TCP_SYN_RECV:
+                               return SKNLGRP_INET_TCP_3WH;
+                       default:
+                               return SKNLGRP_INET_TCP_DESTROY;
+                       }
                case IPPROTO_UDP:
                        return SKNLGRP_INET_UDP_DESTROY;
                default:
@@ -67,15 +75,15 @@ enum sknetlink_groups sock_diag_destroy_group(const struct sock *sk)
 }
 
 static inline
-bool sock_diag_has_destroy_listeners(const struct sock *sk)
+bool sock_diag_has_listeners(const struct sock *sk)
 {
        const struct net *n = sock_net(sk);
-       const enum sknetlink_groups group = sock_diag_destroy_group(sk);
+       const enum sknetlink_groups group = sock_diag_group(sk);
 
        return group != SKNLGRP_NONE && n->diag_nlsk &&
                netlink_has_listeners(n->diag_nlsk, group);
 }
-void sock_diag_broadcast_destroy(struct sock *sk);
+void sock_diag_broadcast(struct sock *sk);
 
 int sock_diag_destroy(struct sock *sk, int err);
 #endif
diff --git a/include/uapi/linux/sock_diag.h b/include/uapi/linux/sock_diag.h
index e592500..4252674 100644
--- a/include/uapi/linux/sock_diag.h
+++ b/include/uapi/linux/sock_diag.h
@@ -32,6 +32,8 @@ enum sknetlink_groups {
        SKNLGRP_INET_UDP_DESTROY,
        SKNLGRP_INET6_TCP_DESTROY,
        SKNLGRP_INET6_UDP_DESTROY,
+       SKNLGRP_INET_TCP_3WH,
+       SKNLGRP_INET_TCP_CONNECTED,
        __SKNLGRP_MAX,
 };
 #define SKNLGRP_MAX    (__SKNLGRP_MAX - 1)
diff --git a/net/core/sock.c b/net/core/sock.c
index 7e8796a..6684840 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1600,8 +1600,8 @@ static void __sk_free(struct sock *sk)
        if (likely(sk->sk_net_refcnt))
                sock_inuse_add(sock_net(sk), -1);
 
-       if (unlikely(sk->sk_net_refcnt && sock_diag_has_destroy_listeners(sk)))
-               sock_diag_broadcast_destroy(sk);
+       if (unlikely(sk->sk_net_refcnt && sock_diag_has_listeners(sk)))
+               sock_diag_broadcast(sk);
        else
                sk_destruct(sk);
 }
diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c
index 3312a58..dbd9e65 100644
--- a/net/core/sock_diag.c
+++ b/net/core/sock_diag.c
@@ -116,14 +116,14 @@ static size_t sock_diag_nlmsg_size(void)
               + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */
 }
 
-static void sock_diag_broadcast_destroy_work(struct work_struct *work)
+static void sock_diag_broadcast_work(struct work_struct *work)
 {
        struct broadcast_sk *bsk =
                container_of(work, struct broadcast_sk, work);
        struct sock *sk = bsk->sk;
        const struct sock_diag_handler *hndl;
        struct sk_buff *skb;
-       const enum sknetlink_groups group = sock_diag_destroy_group(sk);
+       const enum sknetlink_groups group = sock_diag_group(sk);
        int err = -1;
 
        WARN_ON(group == SKNLGRP_NONE);
@@ -144,11 +144,12 @@ static void sock_diag_broadcast_destroy_work(struct work_struct *work)
        else
                kfree_skb(skb);
 out:
-       sk_destruct(sk);
+       if (group <= SKNLGRP_INET6_UDP_DESTROY)
+               sk_destruct(sk);
        kfree(bsk);
 }
 
-void sock_diag_broadcast_destroy(struct sock *sk)
+void sock_diag_broadcast(struct sock *sk)
 {
        /* Note, this function is often called from an interrupt context. */
        struct broadcast_sk *bsk =
@@ -156,7 +157,7 @@ void sock_diag_broadcast_destroy(struct sock *sk)
        if (!bsk)
                return sk_destruct(sk);
        bsk->sk = sk;
-       INIT_WORK(&bsk->work, sock_diag_broadcast_destroy_work);
+       INIT_WORK(&bsk->work, sock_diag_broadcast_work);
        queue_work(broadcast_wq, &bsk->work);
 }
 
-- 
1.7.1

Reply via email to