Module Name:    src
Committed By:   ozaki-r
Date:           Wed Aug  2 01:28:03 UTC 2017

Modified Files:
        src/sys/netinet6: ip6_forward.c ip6_output.c
        src/sys/netipsec: ipsec.c ipsec.h key.c key.h xform_ah.c xform_esp.c
            xform_ipcomp.c
        src/sys/rump/librump/rumpnet: net_stub.c

Log Message:
Make IPsec SPD MP-safe

We use localcount(9), not psref(9), to make the sptree and secpolicy (SP)
entries MP-safe because SPs need to be referenced over opencrypto
processing that executes a callback in a different context.

SPs on sockets aren't managed by the sptree and can be destroyed in softint.
localcount_drain cannot be used in softint so we delay the destruction of
such SPs to a thread context. To do so, a list to manage such SPs is added
(key_socksplist) and key_timehandler_spd deletes dead SPs in the list.

For more details please read the locking notes in key.c.

Proposed on tech-kern@ and tech-net@


To generate a diff of this commit:
cvs rdiff -u -r1.87 -r1.88 src/sys/netinet6/ip6_forward.c
cvs rdiff -u -r1.192 -r1.193 src/sys/netinet6/ip6_output.c
cvs rdiff -u -r1.112 -r1.113 src/sys/netipsec/ipsec.c
cvs rdiff -u -r1.57 -r1.58 src/sys/netipsec/ipsec.h
cvs rdiff -u -r1.196 -r1.197 src/sys/netipsec/key.c
cvs rdiff -u -r1.25 -r1.26 src/sys/netipsec/key.h
cvs rdiff -u -r1.69 -r1.70 src/sys/netipsec/xform_ah.c
cvs rdiff -u -r1.67 -r1.68 src/sys/netipsec/xform_esp.c
cvs rdiff -u -r1.48 -r1.49 src/sys/netipsec/xform_ipcomp.c
cvs rdiff -u -r1.26 -r1.27 src/sys/rump/librump/rumpnet/net_stub.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/netinet6/ip6_forward.c
diff -u src/sys/netinet6/ip6_forward.c:1.87 src/sys/netinet6/ip6_forward.c:1.88
--- src/sys/netinet6/ip6_forward.c:1.87	Tue May  9 04:24:10 2017
+++ src/sys/netinet6/ip6_forward.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: ip6_forward.c,v 1.87 2017/05/09 04:24:10 ozaki-r Exp $	*/
+/*	$NetBSD: ip6_forward.c,v 1.88 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$KAME: ip6_forward.c,v 1.109 2002/09/11 08:10:17 sakane Exp $	*/
 
 /*
@@ -31,7 +31,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: ip6_forward.c,v 1.87 2017/05/09 04:24:10 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: ip6_forward.c,v 1.88 2017/08/02 01:28:03 ozaki-r Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_gateway.h"
@@ -462,7 +462,7 @@ ip6_forward(struct mbuf *m, int srcrt)
  out:
 #ifdef IPSEC
 	if (sp != NULL)
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 #endif
 	rtcache_unref(rt, ro);
 	if (ro != NULL)

Index: src/sys/netinet6/ip6_output.c
diff -u src/sys/netinet6/ip6_output.c:1.192 src/sys/netinet6/ip6_output.c:1.193
--- src/sys/netinet6/ip6_output.c:1.192	Mon Jun 26 08:01:53 2017
+++ src/sys/netinet6/ip6_output.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: ip6_output.c,v 1.192 2017/06/26 08:01:53 ozaki-r Exp $	*/
+/*	$NetBSD: ip6_output.c,v 1.193 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$KAME: ip6_output.c,v 1.172 2001/03/25 09:55:56 itojun Exp $	*/
 
 /*
@@ -62,7 +62,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: ip6_output.c,v 1.192 2017/06/26 08:01:53 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: ip6_output.c,v 1.193 2017/08/02 01:28:03 ozaki-r Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -1069,7 +1069,7 @@ done:
 
 #ifdef IPSEC
 	if (sp != NULL)
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 #endif /* IPSEC */
 
 	if_put(ifp, &psref);

Index: src/sys/netipsec/ipsec.c
diff -u src/sys/netipsec/ipsec.c:1.112 src/sys/netipsec/ipsec.c:1.113
--- src/sys/netipsec/ipsec.c:1.112	Wed Jul 26 07:39:54 2017
+++ src/sys/netipsec/ipsec.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: ipsec.c,v 1.112 2017/07/26 07:39:54 ozaki-r Exp $	*/
+/*	$NetBSD: ipsec.c,v 1.113 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: /usr/local/www/cvsroot/FreeBSD/src/sys/netipsec/ipsec.c,v 1.2.2.2 2003/07/01 01:38:13 sam Exp $	*/
 /*	$KAME: ipsec.c,v 1.103 2001/05/24 07:14:18 sakane Exp $	*/
 
@@ -32,7 +32,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: ipsec.c,v 1.112 2017/07/26 07:39:54 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: ipsec.c,v 1.113 2017/08/02 01:28:03 ozaki-r Exp $");
 
 /*
  * IPsec controller part.
@@ -59,6 +59,7 @@ __KERNEL_RCSID(0, "$NetBSD: ipsec.c,v 1.
 #include <sys/kauth.h>
 #include <sys/cpu.h>
 #include <sys/kmem.h>
+#include <sys/pserialize.h>
 
 #include <net/if.h>
 #include <net/route.h>
@@ -201,6 +202,7 @@ static struct secpolicy *ipsec_deepcopy_
 static int ipsec_set_policy (struct secpolicy **, int, const void *, size_t,
     kauth_cred_t);
 static int ipsec_get_policy (struct secpolicy *, struct mbuf **);
+static void ipsec_destroy_policy(struct secpolicy *);
 static void vshiftl (unsigned char *, int, int);
 static size_t ipsec_hdrsiz (const struct secpolicy *);
 
@@ -211,34 +213,49 @@ static struct secpolicy *
 ipsec_checkpcbcache(struct mbuf *m, struct inpcbpolicy *pcbsp, int dir)
 {
 	struct secpolicyindex spidx;
+	struct secpolicy *sp = NULL;
+	int s;
 
 	KASSERT(IPSEC_DIR_IS_VALID(dir));
 	KASSERT(pcbsp != NULL);
 	KASSERT(dir < __arraycount(pcbsp->sp_cache));
 	KASSERT(inph_locked(pcbsp->sp_inph));
 
+	/*
+	 * Checking the generation and sp->state and taking a reference to an SP
+	 * must be in a critical section of pserialize. See key_unlink_sp.
+	 */
+	s = pserialize_read_enter();
 	/* SPD table change invalidate all the caches. */
 	if (ipsec_spdgen != pcbsp->sp_cache[dir].cachegen) {
 		ipsec_invalpcbcache(pcbsp, dir);
-		return NULL;
+		goto out;
 	}
-	if (!pcbsp->sp_cache[dir].cachesp)
-		return NULL;
-	if (pcbsp->sp_cache[dir].cachesp->state != IPSEC_SPSTATE_ALIVE) {
+	sp = pcbsp->sp_cache[dir].cachesp;
+	if (sp == NULL)
+		goto out;
+	if (sp->state != IPSEC_SPSTATE_ALIVE) {
+		sp = NULL;
 		ipsec_invalpcbcache(pcbsp, dir);
-		return NULL;
+		goto out;
 	}
 	if ((pcbsp->sp_cacheflags & IPSEC_PCBSP_CONNECTED) == 0) {
-		if (ipsec_setspidx(m, &spidx, 1) != 0)
-			return NULL;
+		/* NB: assume ipsec_setspidx never sleeps */
+		if (ipsec_setspidx(m, &spidx, 1) != 0) {
+			sp = NULL;
+			goto out;
+		}
 
 		/*
 		 * We have to make an exact match here since the cached rule
 		 * might have lower priority than a rule that would otherwise
 		 * have matched the packet. 
 		 */
-		if (memcmp(&pcbsp->sp_cache[dir].cacheidx, &spidx, sizeof(spidx))) 
-			return NULL;
+		if (memcmp(&pcbsp->sp_cache[dir].cacheidx, &spidx,
+		    sizeof(spidx))) {
+			sp = NULL;
+			goto out;
+		}
 	} else {
 		/*
 		 * The pcb is connected, and the L4 code is sure that:
@@ -252,13 +269,14 @@ ipsec_checkpcbcache(struct mbuf *m, stru
 		 */
 	}
 
-	pcbsp->sp_cache[dir].cachesp->lastused = time_second;
-	KEY_SP_REF(pcbsp->sp_cache[dir].cachesp);
+	sp->lastused = time_second;
+	KEY_SP_REF(sp);
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 	    "DP cause refcnt++:%d SP:%p\n",
-	    key_sp_refcnt(pcbsp->sp_cache[dir].cachesp),
-	    pcbsp->sp_cache[dir].cachesp);
-	return pcbsp->sp_cache[dir].cachesp;
+	    key_sp_refcnt(sp), pcbsp->sp_cache[dir].cachesp);
+out:
+	pserialize_read_exit(s);
+	return sp;
 }
 
 static int
@@ -270,8 +288,6 @@ ipsec_fillpcbcache(struct inpcbpolicy *p
 	KASSERT(dir < __arraycount(pcbsp->sp_cache));
 	KASSERT(inph_locked(pcbsp->sp_inph));
 
-	if (pcbsp->sp_cache[dir].cachesp)
-		KEY_FREESP(&pcbsp->sp_cache[dir].cachesp);
 	pcbsp->sp_cache[dir].cachesp = NULL;
 	pcbsp->sp_cache[dir].cachehint = IPSEC_PCBHINT_UNKNOWN;
 	if (ipsec_setspidx(m, &pcbsp->sp_cache[dir].cacheidx, 1) != 0) {
@@ -279,7 +295,6 @@ ipsec_fillpcbcache(struct inpcbpolicy *p
 	}
 	pcbsp->sp_cache[dir].cachesp = sp;
 	if (pcbsp->sp_cache[dir].cachesp) {
-		KEY_SP_REF(pcbsp->sp_cache[dir].cachesp);
 		KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 		    "DP cause refcnt++:%d SP:%p\n",
 		    key_sp_refcnt(pcbsp->sp_cache[dir].cachesp),
@@ -317,8 +332,6 @@ ipsec_invalpcbcache(struct inpcbpolicy *
 	for (i = IPSEC_DIR_INBOUND; i <= IPSEC_DIR_OUTBOUND; i++) {
 		if (dir != IPSEC_DIR_ANY && i != dir)
 			continue;
-		if (pcbsp->sp_cache[i].cachesp)
-			KEY_FREESP(&pcbsp->sp_cache[i].cachesp);
 		pcbsp->sp_cache[i].cachesp = NULL;
 		pcbsp->sp_cache[i].cachehint = IPSEC_PCBHINT_UNKNOWN;
 		pcbsp->sp_cache[i].cachegen = 0;
@@ -609,7 +622,7 @@ ipsec4_checkpolicy(struct mbuf *m, u_int
 		break;
 	case IPSEC_POLICY_BYPASS:
 	case IPSEC_POLICY_NONE:
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 		sp = NULL;		/* NB: force NULL result */
 		break;
 	case IPSEC_POLICY_IPSEC:
@@ -617,7 +630,7 @@ ipsec4_checkpolicy(struct mbuf *m, u_int
 		break;
 	}
 	if (*error != 0) {
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 		sp = NULL;
 		IPSECLOG(LOG_DEBUG, "done, error %d\n", *error);
 	}
@@ -697,7 +710,7 @@ ipsec4_output(struct mbuf *m, struct inp
 		 */
 		*mtu = _mtu;
 		*natt_frag = true;
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 		splx(s);
 		return 0;
 	}
@@ -711,7 +724,7 @@ ipsec4_output(struct mbuf *m, struct inp
 	 */
 	if (error == ENOENT)
 		error = 0;
-	KEY_FREESP(&sp);
+	KEY_SP_UNREF(&sp);
 	splx(s);
 	*done = true;
 	return error;
@@ -734,7 +747,7 @@ ipsec4_input(struct mbuf *m, int flags)
 	 * Check security policy against packet attributes.
 	 */
 	error = ipsec_in_reject(sp, m);
-	KEY_FREESP(&sp);
+	KEY_SP_UNREF(&sp);
 	splx(s);
 	if (error) {
 		return error;
@@ -753,7 +766,7 @@ ipsec4_input(struct mbuf *m, int flags)
 	sp = ipsec4_checkpolicy(m, IPSEC_DIR_OUTBOUND, flags, &error, NULL);
 	if (sp != NULL) {
 		m->m_flags &= ~M_CANFASTFWD;
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	}
 	splx(s);
 	return 0;
@@ -802,7 +815,7 @@ ipsec4_forward(struct mbuf *m, int *dest
 		rtcache_unref(rt, ro);
 		KEY_FREESAV(&sav);
 	}
-	KEY_FREESP(&sp);
+	KEY_SP_UNREF(&sp);
 	return 0;
 }
 
@@ -838,7 +851,7 @@ ipsec6_checkpolicy(struct mbuf *m, u_int
 		break;
 	case IPSEC_POLICY_BYPASS:
 	case IPSEC_POLICY_NONE:
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 		sp = NULL;	  /* NB: force NULL result */
 		break;
 	case IPSEC_POLICY_IPSEC:
@@ -846,7 +859,7 @@ ipsec6_checkpolicy(struct mbuf *m, u_int
 		break;
 	}
 	if (*error != 0) {
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 		sp = NULL;
 		IPSECLOG(LOG_DEBUG, "done, error %d\n", *error);
 	}
@@ -1236,20 +1249,26 @@ ipsec_init_policy(struct socket *so, str
 	else
 		new->priv = 0;
 
+	/*
+	 * These SPs are dummies. They are never used because the policy
+	 * is ENTRUST. See ipsec_getpolicybysock.
+	 */
 	if ((new->sp_in = KEY_NEWSP()) == NULL) {
 		ipsec_delpcbpolicy(new);
 		return ENOBUFS;
 	}
 	new->sp_in->state = IPSEC_SPSTATE_ALIVE;
 	new->sp_in->policy = IPSEC_POLICY_ENTRUST;
+	new->sp_in->created = 0; /* Indicates dummy */
 
 	if ((new->sp_out = KEY_NEWSP()) == NULL) {
-		KEY_FREESP(&new->sp_in);
+		KEY_SP_UNREF(&new->sp_in);
 		ipsec_delpcbpolicy(new);
 		return ENOBUFS;
 	}
 	new->sp_out->state = IPSEC_SPSTATE_ALIVE;
 	new->sp_out->policy = IPSEC_POLICY_ENTRUST;
+	new->sp_out->created = 0; /* Indicates dummy */
 
 	*policy = new;
 
@@ -1264,14 +1283,14 @@ ipsec_copy_policy(const struct inpcbpoli
 
 	sp = ipsec_deepcopy_policy(old->sp_in);
 	if (sp) {
-		KEY_FREESP(&new->sp_in);
+		KEY_SP_UNREF(&new->sp_in);
 		new->sp_in = sp;
 	} else
 		return ENOBUFS;
 
 	sp = ipsec_deepcopy_policy(old->sp_out);
 	if (sp) {
-		KEY_FREESP(&new->sp_out);
+		KEY_SP_UNREF(&new->sp_out);
 		new->sp_out = sp;
 	} else
 		return ENOBUFS;
@@ -1326,6 +1345,23 @@ ipsec_deepcopy_policy(const struct secpo
 	return dst;
 }
 
+static void
+ipsec_destroy_policy(struct secpolicy *sp)
+{
+
+	if (sp->created == 0)
+		/* It's a dummy, so we can simply free it. */
+		key_free_sp(sp);
+	else {
+		/*
+		 * We cannot destroy here because it can be called in
+		 * softint. So mark the SP as DEAD and let the timer
+		 * destroy it. See key_timehandler_spd.
+		 */
+		sp->state = IPSEC_SPSTATE_DEAD;
+	}
+}
+
 /* set policy and ipsec request if present. */
 static int
 ipsec_set_policy(
@@ -1337,7 +1373,7 @@ ipsec_set_policy(
 )
 {
 	const struct sadb_x_policy *xpl;
-	struct secpolicy *newsp = NULL;
+	struct secpolicy *newsp = NULL, *oldsp;
 	int error;
 
 	KASSERT(!cpu_softintr_p());
@@ -1372,11 +1408,16 @@ ipsec_set_policy(
 	if ((newsp = key_msg2sp(xpl, len, &error)) == NULL)
 		return error;
 
-	newsp->state = IPSEC_SPSTATE_ALIVE;
+	key_init_sp(newsp);
+	newsp->created = time_uptime;
+	/* Insert into the global list of SPs for sockets */
+	key_socksplist_add(newsp);
 
 	/* clear old SP and set new SP */
-	KEY_FREESP(policy);
+	oldsp = *policy;
 	*policy = newsp;
+	ipsec_destroy_policy(oldsp);
+
 	if (KEYDEBUG_ON(KEYDEBUG_IPSEC_DUMP)) {
 		printf("%s: new policy\n", __func__);
 		kdebug_secpolicy(newsp);
@@ -1416,6 +1457,7 @@ ipsec4_set_policy(struct inpcb *inp, int
 	struct secpolicy **policy;
 
 	KASSERT(!cpu_softintr_p());
+	KASSERT(inp_locked(inp));
 
 	/* sanity check. */
 	if (inp == NULL || request == NULL)
@@ -1486,10 +1528,10 @@ ipsec4_delete_pcbpolicy(struct inpcb *in
 		return 0;
 
 	if (inp->inp_sp->sp_in != NULL)
-		KEY_FREESP(&inp->inp_sp->sp_in);
+		ipsec_destroy_policy(inp->inp_sp->sp_in);
 
 	if (inp->inp_sp->sp_out != NULL)
-		KEY_FREESP(&inp->inp_sp->sp_out);
+		ipsec_destroy_policy(inp->inp_sp->sp_out);
 
 	ipsec_invalpcbcache(inp->inp_sp, IPSEC_DIR_ANY);
 
@@ -1508,6 +1550,7 @@ ipsec6_set_policy(struct in6pcb *in6p, i
 	struct secpolicy **policy;
 
 	KASSERT(!cpu_softintr_p());
+	KASSERT(in6p_locked(in6p));
 
 	/* sanity check. */
 	if (in6p == NULL || request == NULL)
@@ -1575,10 +1618,10 @@ ipsec6_delete_pcbpolicy(struct in6pcb *i
 		return 0;
 
 	if (in6p->in6p_sp->sp_in != NULL)
-		KEY_FREESP(&in6p->in6p_sp->sp_in);
+		ipsec_destroy_policy(in6p->in6p_sp->sp_in);
 
 	if (in6p->in6p_sp->sp_out != NULL)
-		KEY_FREESP(&in6p->in6p_sp->sp_out);
+		ipsec_destroy_policy(in6p->in6p_sp->sp_out);
 
 	ipsec_invalpcbcache(in6p->in6p_sp, IPSEC_DIR_ANY);
 
@@ -1778,7 +1821,7 @@ ipsec4_in_reject(struct mbuf *m, struct 
 		result = ipsec_in_reject(sp, m);
 		if (result)
 			IPSEC_STATINC(IPSEC_STAT_IN_POLVIO);
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	} else {
 		result = 0;	/* XXX should be panic ?
 				 * -> No, there may be error. */
@@ -1817,7 +1860,7 @@ ipsec6_in_reject(struct mbuf *m, struct 
 		result = ipsec_in_reject(sp, m);
 		if (result)
 			IPSEC_STATINC(IPSEC_STAT_IN_POLVIO);
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	} else {
 		result = 0;
 	}
@@ -1929,7 +1972,7 @@ ipsec4_hdrsiz(struct mbuf *m, u_int dir,
 		KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_DATA, "size:%lu.\n",
 		    (unsigned long)size);
 
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	} else {
 		size = 0;	/* XXX should be panic ? */
 	}
@@ -1964,7 +2007,7 @@ ipsec6_hdrsiz(struct mbuf *m, u_int dir,
 		return 0;
 	size = ipsec_hdrsiz(sp);
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_DATA, "size:%zu.\n", size);
-	KEY_FREESP(&sp);
+	KEY_SP_UNREF(&sp);
 
 	return size;
 }
@@ -2279,7 +2322,7 @@ ipsec6_input(struct mbuf *m)
 		 * attributes.
 		 */
 		error = ipsec_in_reject(sp, m);
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	} else {
 		/* XXX error stat??? */
 		error = EINVAL;

Index: src/sys/netipsec/ipsec.h
diff -u src/sys/netipsec/ipsec.h:1.57 src/sys/netipsec/ipsec.h:1.58
--- src/sys/netipsec/ipsec.h:1.57	Wed Jul 26 09:18:15 2017
+++ src/sys/netipsec/ipsec.h	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: ipsec.h,v 1.57 2017/07/26 09:18:15 ozaki-r Exp $	*/
+/*	$NetBSD: ipsec.h,v 1.58 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: /usr/local/www/cvsroot/FreeBSD/src/sys/netipsec/ipsec.h,v 1.2.4.2 2004/02/14 22:23:23 bms Exp $	*/
 /*	$KAME: ipsec.h,v 1.53 2001/11/20 08:32:38 itojun Exp $	*/
 
@@ -47,6 +47,7 @@
 
 #ifdef _KERNEL
 #include <sys/socketvar.h>
+#include <sys/localcount.h>
 
 #include <netinet/in_pcb_hdr.h>
 #include <netipsec/keydb.h>
@@ -76,7 +77,7 @@ struct secpolicyindex {
 struct secpolicy {
 	struct pslist_entry pslist_entry;
 
-	u_int refcnt;			/* reference count */
+	struct localcount localcount;	/* reference count */
 	struct secpolicyindex spidx;	/* selector */
 	u_int32_t id;			/* It's unique number on the system. */
 	u_int state;			/* 0: dead, others: alive */

Index: src/sys/netipsec/key.c
diff -u src/sys/netipsec/key.c:1.196 src/sys/netipsec/key.c:1.197
--- src/sys/netipsec/key.c:1.196	Thu Jul 27 09:53:57 2017
+++ src/sys/netipsec/key.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: key.c,v 1.196 2017/07/27 09:53:57 ozaki-r Exp $	*/
+/*	$NetBSD: key.c,v 1.197 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/key.c,v 1.3.2.3 2004/02/14 22:23:23 bms Exp $	*/
 /*	$KAME: key.c,v 1.191 2001/06/27 10:46:49 sakane Exp $	*/
 
@@ -32,7 +32,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: key.c,v 1.196 2017/07/27 09:53:57 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: key.c,v 1.197 2017/08/02 01:28:03 ozaki-r Exp $");
 
 /*
  * This code is referd to RFC 2367
@@ -42,6 +42,7 @@ __KERNEL_RCSID(0, "$NetBSD: key.c,v 1.19
 #include "opt_inet.h"
 #include "opt_ipsec.h"
 #include "opt_gateway.h"
+#include "opt_net_mpsafe.h"
 #endif
 
 #include <sys/types.h>
@@ -67,6 +68,10 @@ __KERNEL_RCSID(0, "$NetBSD: key.c,v 1.19
 #include <sys/cpu.h>
 #include <sys/atomic.h>
 #include <sys/pslist.h>
+#include <sys/mutex.h>
+#include <sys/condvar.h>
+#include <sys/localcount.h>
+#include <sys/pserialize.h>
 
 #include <net/if.h>
 #include <net/route.h>
@@ -130,6 +135,50 @@ percpu_t *pfkeystat_percpu;
  *   field hits 0 (= no external reference other than from SA header.
  */
 
+/*
+ * Locking notes on SPD:
+ * - Modifications to the sptree must be done while holding key_sp_mtx,
+ *   which is an adaptive mutex
+ * - Read accesses to the sptree must be in critical sections of pserialize(9)
+ * - SP's lifetime is managed by localcount(9)
+ * - An SP that has been inserted to the sptree is initially referenced by none,
+ *   i.e., a reference from the sptree isn't counted
+ * - When an SP is being destroyed, we change its state to DEAD, wait for
+ *   references to the SP to be released, and then deallocate the SP
+ *   (see key_unlink_sp)
+ * - Getting an SP
+ *   - Normally we get an SP from the sptree by incrementing the reference count
+ *     of the SP
+ *   - We can gain another reference from a held SP only if we check its state
+ *     and take its reference in a critical section of pserialize
+ *     (see esp_output for example)
+ *   - We may get an SP from an SP cache. See below
+ * - Updating member variables of an SP
+ *   - Most member variables of an SP are immutable
+ *   - Only sp->state and sp->lastused can be changed
+ *   - sp->state of an SP is updated only when destroying it under key_sp_mtx
+ * - SP caches
+ *   - SPs can be cached in PCBs
+ *   - The lifetime of the caches is controlled by the global generation counter
+ *     (ipsec_spdgen)
+ *   - The global counter value is stored when an SP is cached
+ *   - If the stored value is different from the global counter then the cache
+ *     is considered invalidated
+ *   - The counter is incremented when an SP is being destroyed
+ *   - So checking the generation and taking a reference to an SP should be
+ *     in a critical section of pserialize
+ *   - Note that caching doesn't increment the reference counter of an SP
+ * - SPs in sockets
+ *   - Userland programs can set a policy to a socket by
+ *     setsockopt(IP_IPSEC_POLICY)
+ *   - Such policies (SPs) are set to a socket (PCB) and also inserted to
+ *     the key_socksplist list (not the sptree)
+ *   - Such a policy is destroyed when the corresponding socket is destroyed;
+ *     however, a socket can be destroyed in softint, so we cannot destroy
+ *     the policy directly. Instead we just mark it DEAD and delay the
+ *     destruction until GC by the timer (see key_timehandler_spd)
+ */
+
 u_int32_t key_debug_level = 0;
 static u_int key_spi_trycnt = 1000;
 static u_int32_t key_spi_minval = 0x100;
@@ -195,9 +244,22 @@ static LIST_HEAD(_spacqtree, secspacq) s
 	} while (0)
 
 /*
+ * The list has SPs that are set to a socket via setsockopt(IP_IPSEC_POLICY)
+ * from userland. See ipsec_set_policy.
+ */
+static struct pslist_head key_socksplist;
+
+#define SOCKSPLIST_WRITER_FOREACH(sp)					\
+	PSLIST_WRITER_FOREACH((sp), &key_socksplist, struct secpolicy,	\
+	                      pslist_entry)
+
+/*
  * Protect regtree, acqtree and items stored in the lists.
  */
 static kmutex_t key_mtx __cacheline_aligned;
+static pserialize_t key_psz;
+static kmutex_t key_sp_mtx __cacheline_aligned;
+static kcondvar_t key_sp_cv __cacheline_aligned;
 
 /* search order for SAs */
 	/*
@@ -432,9 +494,11 @@ static struct secasvar *key_lookup_sa_by
 static void key_freeso(struct socket *);
 static void key_freesp_so(struct secpolicy **);
 #endif
-static void key_delsp (struct secpolicy *);
 static struct secpolicy *key_getsp (const struct secpolicyindex *);
 static struct secpolicy *key_getspbyid (u_int32_t);
+static struct secpolicy *key_lookup_and_remove_sp(const struct secpolicyindex *);
+static struct secpolicy *key_lookupbyid_and_remove_sp(u_int32_t);
+static void key_destroy_sp(struct secpolicy *);
 static u_int16_t key_newreqid (void);
 static struct mbuf *key_gather_mbuf (struct mbuf *,
 	const struct sadb_msghdr *, int, int, ...);
@@ -582,8 +646,6 @@ static const char *key_getfqdn (void);
 static const char *key_getuserfqdn (void);
 #endif
 static void key_sa_chgstate (struct secasvar *, u_int8_t);
-static inline void key_sp_dead (struct secpolicy *);
-static void key_sp_unlink (struct secpolicy *sp);
 
 static struct mbuf *key_alloc_mbuf (int);
 
@@ -622,55 +684,37 @@ static struct work	key_timehandler_wk;
 	REFLOG("SA_DELREF", (p), (where), (tag));			\
 } while (0)
 
-#define	SP_ADDREF(p) do {						\
-	atomic_inc_uint(&(p)->refcnt);					\
-	REFLOG("SP_ADDREF", (p), __func__, __LINE__);			\
-	KASSERTMSG((p)->refcnt != 0, "SP refcnt overflow");		\
-} while (0)
-#define	SP_ADDREF2(p, where, tag) do {					\
-	atomic_inc_uint(&(p)->refcnt);					\
-	REFLOG("SP_ADDREF", (p), (where), (tag));			\
-	KASSERTMSG((p)->refcnt != 0, "SP refcnt overflow");		\
-} while (0)
-#define	SP_DELREF(p) do {						\
-	KASSERTMSG((p)->refcnt > 0, "SP refcnt underflow");		\
-	atomic_dec_uint(&(p)->refcnt);					\
-	REFLOG("SP_DELREF", (p), __func__, __LINE__);			\
-} while (0)
-#define	SP_DELREF2(p, nv, where, tag) do {				\
-	KASSERTMSG((p)->refcnt > 0, "SP refcnt underflow");		\
-	nv = atomic_dec_uint_nv(&(p)->refcnt);				\
-	REFLOG("SP_DELREF", (p), (where), (tag));			\
-} while (0)
-
 u_int
 key_sp_refcnt(const struct secpolicy *sp)
 {
 
-	if (sp == NULL)
-		return 0;
-
-	return sp->refcnt;
+	/* FIXME */
+	return 0;
 }
 
-static inline void
-key_sp_dead(struct secpolicy *sp)
+/*
+ * Remove the sp from the sptree and wait for references to the sp
+ * to be released. key_sp_mtx must be held.
+ */
+static void
+key_unlink_sp(struct secpolicy *sp)
 {
 
-	/* mark the SP dead */
+	KASSERT(mutex_owned(&key_sp_mtx));
+
 	sp->state = IPSEC_SPSTATE_DEAD;
-}
+	SPLIST_WRITER_REMOVE(sp);
 
-static void
-key_sp_unlink(struct secpolicy *sp)
-{
+	/* Invalidate all cached SPD pointers in the PCBs. */
+	ipsec_invalpcbcacheall();
 
-	/* remove from SP index */
-	SPLIST_WRITER_REMOVE(sp);
-	/* Release refcount held just for being on chain */
-	KEY_FREESP(&sp);
-}
+#ifdef NET_MPSAFE
+	KASSERT(mutex_ownable(softnet_lock));
+	pserialize_perform(key_psz);
+#endif
 
+	localcount_drain(&sp->localcount, &key_sp_cv, &key_sp_mtx);
+}
 
 /*
  * Return 0 when there are known to be no SP's for the specified
@@ -704,12 +748,12 @@ key_lookup_sp_byspidx(const struct secpo
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP, "DP from %s:%u\n", where, tag);
 
 	/* get a SP entry */
-	s = splsoftnet();	/*called from softclock()*/
 	if (KEYDEBUG_ON(KEYDEBUG_IPSEC_DATA)) {
 		printf("*** objects\n");
 		kdebug_secpolicyindex(spidx);
 	}
 
+	s = pserialize_read_enter();
 	SPLIST_READER_FOREACH(sp, dir) {
 		if (KEYDEBUG_ON(KEYDEBUG_IPSEC_DATA)) {
 			printf("*** in SPD\n");
@@ -729,9 +773,9 @@ found:
 
 		/* found a SPD entry */
 		sp->lastused = time_uptime;
-		SP_ADDREF2(sp, where, tag);
+		key_sp_ref(sp, where, tag);
 	}
-	splx(s);
+	pserialize_read_exit(s);
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 	    "DP return SP:%p (ID=%u) refcnt %u\n",
@@ -765,7 +809,7 @@ key_gettunnel(const struct sockaddr *osr
 		goto done;
 	}
 
-	s = splsoftnet();	/*called from softclock()*/
+	s = pserialize_read_enter();
 	SPLIST_READER_FOREACH(sp, dir) {
 		if (sp->state == IPSEC_SPSTATE_DEAD)
 			continue;
@@ -805,9 +849,9 @@ key_gettunnel(const struct sockaddr *osr
 found:
 	if (sp) {
 		sp->lastused = time_uptime;
-		SP_ADDREF2(sp, where, tag);
+		key_sp_ref(sp, where, tag);
 	}
-	splx(s);
+	pserialize_read_exit(s);
 done:
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 	    "DP return SP:%p (ID=%u) refcnt %u\n",
@@ -1154,49 +1198,51 @@ key_validate_savlist(const struct secash
 }
 
 void
+key_init_sp(struct secpolicy *sp)
+{
+
+	ASSERT_SLEEPABLE();
+
+	sp->state = IPSEC_SPSTATE_ALIVE;
+	if (sp->policy == IPSEC_POLICY_IPSEC)
+		KASSERT(sp->req != NULL);
+	localcount_init(&sp->localcount);
+	SPLIST_ENTRY_INIT(sp);
+}
+
+void
 key_sp_ref(struct secpolicy *sp, const char* where, int tag)
 {
 
-	SP_ADDREF2(sp, where, tag);
+	localcount_acquire(&sp->localcount);
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
-	    "DP SP:%p (ID=%u) from %s:%u; refcnt now %u\n",
+	    "DP SP:%p (ID=%u) from %s:%u; refcnt++ now %u\n",
 	    sp, sp->id, where, tag, key_sp_refcnt(sp));
 }
 
 void
-key_sa_ref(struct secasvar *sav, const char* where, int tag)
+key_sp_unref(struct secpolicy *sp, const char* where, int tag)
 {
 
-	SA_ADDREF2(sav, where, tag);
+	KASSERT(mutex_ownable(&key_sp_mtx));
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
-	    "DP cause refcnt++:%d SA:%p from %s:%u\n",
-	    sav->refcnt, sav, where, tag);
+	    "DP SP:%p (ID=%u) from %s:%u; refcnt-- now %u\n",
+	    sp, sp->id, where, tag, key_sp_refcnt(sp));
+
+	localcount_release(&sp->localcount, &key_sp_cv, &key_sp_mtx);
 }
 
-/*
- * Must be called after calling key_lookup_sp*().
- * For both the packet without socket and key_freeso().
- */
 void
-_key_freesp(struct secpolicy **spp, const char* where, int tag)
+key_sa_ref(struct secasvar *sav, const char* where, int tag)
 {
-	struct secpolicy *sp = *spp;
-	unsigned int nv;
 
-	KASSERT(sp != NULL);
-
-	SP_DELREF2(sp, nv, where, tag);
+	SA_ADDREF2(sav, where, tag);
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
-	    "DP SP:%p (ID=%u) from %s:%u; refcnt now %u\n",
-	    sp, sp->id, where, tag, nv);
-
-	if (nv == 0) {
-		*spp = NULL;
-		key_delsp(sp);
-	}
+	    "DP cause refcnt++:%d SA:%p from %s:%u\n",
+	    sav->refcnt, sav, where, tag);
 }
 
 #if 0
@@ -1270,7 +1316,7 @@ key_freesp_so(struct secpolicy **sp)
 
 	KASSERTMSG((*sp)->policy == IPSEC_POLICY_IPSEC,
 	    "invalid policy %u", (*sp)->policy);
-	KEY_FREESP(sp);
+	KEY_SP_UNREF(&sp);
 }
 #endif
 
@@ -1309,20 +1355,18 @@ key_freesav(struct secasvar **psav, cons
  * free security policy entry.
  */
 static void
-key_delsp(struct secpolicy *sp)
+key_destroy_sp(struct secpolicy *sp)
 {
-	int s;
-
-	KASSERT(sp != NULL);
 
-	key_sp_dead(sp);
-
-	KASSERTMSG(sp->refcnt == 0,
-	    "SP with references deleted (refcnt %u)", sp->refcnt);
+	SPLIST_ENTRY_DESTROY(sp);
+	localcount_fini(&sp->localcount);
 
-	s = splsoftnet();	/*called from softclock()*/
+	key_free_sp(sp);
+}
 
-    {
+void
+key_free_sp(struct secpolicy *sp)
+{
 	struct ipsecrequest *isr = sp->req, *nextisr;
 
 	while (isr != NULL) {
@@ -1330,12 +1374,17 @@ key_delsp(struct secpolicy *sp)
 		kmem_intr_free(isr, sizeof(*isr));
 		isr = nextisr;
 	}
-    }
 
-	SPLIST_ENTRY_DESTROY(sp);
 	kmem_intr_free(sp, sizeof(*sp));
+}
 
-	splx(s);
+void
+key_socksplist_add(struct secpolicy *sp)
+{
+
+	mutex_enter(&key_sp_mtx);
+	PSLIST_WRITER_INSERT_HEAD(&key_socksplist, sp, pslist_entry);
+	mutex_exit(&key_sp_mtx);
 }
 
 /*
@@ -1347,22 +1396,52 @@ static struct secpolicy *
 key_getsp(const struct secpolicyindex *spidx)
 {
 	struct secpolicy *sp;
+	int s;
 
 	KASSERT(spidx != NULL);
 
+	s = pserialize_read_enter();
 	SPLIST_READER_FOREACH(sp, spidx->dir) {
 		if (sp->state == IPSEC_SPSTATE_DEAD)
 			continue;
 		if (key_spidx_match_exactly(spidx, &sp->spidx)) {
-			SP_ADDREF(sp);
+			KEY_SP_REF(sp);
+			pserialize_read_exit(s);
 			return sp;
 		}
 	}
+	pserialize_read_exit(s);
 
 	return NULL;
 }
 
 /*
+ * search SPD and remove found SP
+ * OUT:	NULL	: not found
+ *	others	: found, pointer to a SP.
+ */
+static struct secpolicy *
+key_lookup_and_remove_sp(const struct secpolicyindex *spidx)
+{
+	struct secpolicy *sp = NULL;
+
+	mutex_enter(&key_sp_mtx);
+	SPLIST_WRITER_FOREACH(sp, spidx->dir) {
+		KASSERT(sp->state != IPSEC_SPSTATE_DEAD);
+
+		if (key_spidx_match_exactly(spidx, &sp->spidx)) {
+			key_unlink_sp(sp);
+			goto out;
+		}
+	}
+	sp = NULL;
+out:
+	mutex_exit(&key_sp_mtx);
+
+	return sp;
+}
+
+/*
  * get SP by index.
  * OUT:	NULL	: not found
  *	others	: found, pointer to a SP.
@@ -1371,13 +1450,15 @@ static struct secpolicy *
 key_getspbyid(u_int32_t id)
 {
 	struct secpolicy *sp;
+	int s;
 
+	s = pserialize_read_enter();
 	SPLIST_READER_FOREACH(sp, IPSEC_DIR_INBOUND) {
 		if (sp->state == IPSEC_SPSTATE_DEAD)
 			continue;
 		if (sp->id == id) {
-			SP_ADDREF(sp);
-			return sp;
+			KEY_SP_REF(sp);
+			goto out;
 		}
 	}
 
@@ -1385,12 +1466,42 @@ key_getspbyid(u_int32_t id)
 		if (sp->state == IPSEC_SPSTATE_DEAD)
 			continue;
 		if (sp->id == id) {
-			SP_ADDREF(sp);
-			return sp;
+			KEY_SP_REF(sp);
+			goto out;
 		}
 	}
+out:
+	pserialize_read_exit(s);
+	return sp;
+}
 
-	return NULL;
+/*
+ * get SP by id, remove it from the sptree and return it.
+ * OUT:	NULL	: not found
+ *	others	: found, pointer to a SP.
+ */
+static struct secpolicy *
+key_lookupbyid_and_remove_sp(u_int32_t id)
+{
+	struct secpolicy *sp;
+
+	mutex_enter(&key_sp_mtx);
+	SPLIST_READER_FOREACH(sp, IPSEC_DIR_INBOUND) {
+		KASSERT(sp->state != IPSEC_SPSTATE_DEAD);
+		if (sp->id == id)
+			goto out;
+	}
+
+	SPLIST_READER_FOREACH(sp, IPSEC_DIR_OUTBOUND) {
+		KASSERT(sp->state != IPSEC_SPSTATE_DEAD);
+		if (sp->id == id)
+			goto out;
+	}
+out:
+	if (sp != NULL)
+		key_unlink_sp(sp);
+	mutex_exit(&key_sp_mtx);
+	return sp;
 }
 
 struct secpolicy *
@@ -1399,8 +1510,6 @@ key_newsp(const char* where, int tag)
 	struct secpolicy *newsp = NULL;
 
 	newsp = kmem_intr_zalloc(sizeof(struct secpolicy), KM_NOSLEEP);
-	if (newsp != NULL)
-		newsp->refcnt = 1;
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 	    "DP from %s:%u return SP:%p\n", where, tag, newsp);
@@ -1451,7 +1560,7 @@ key_msg2sp(const struct sadb_x_policy *x
 		break;
 	default:
 		IPSECLOG(LOG_DEBUG, "invalid policy type.\n");
-		KEY_FREESP(&newsp);
+		key_free_sp(newsp);
 		*error = EINVAL;
 		return NULL;
 	}
@@ -1605,7 +1714,7 @@ key_msg2sp(const struct sadb_x_policy *x
 	return newsp;
 
 free_exit:
-	KEY_FREESP(&newsp);
+	key_free_sp(newsp);
 	return NULL;
 }
 
@@ -1857,16 +1966,14 @@ key_api_spdadd(struct socket *so, struct
     {
 	struct secpolicy *sp;
 
-	sp = key_getsp(&spidx);
 	if (mhp->msg->sadb_msg_type == SADB_X_SPDUPDATE) {
-		if (sp) {
-			key_sp_dead(sp);
-			key_sp_unlink(sp);	/* XXX jrs ordering */
-			KEY_FREESP(&sp);
-		}
+		sp = key_lookup_and_remove_sp(&spidx);
+		if (sp != NULL)
+			key_destroy_sp(sp);
 	} else {
+		sp = key_getsp(&spidx);
 		if (sp != NULL) {
-			KEY_FREESP(&sp);
+			KEY_SP_UNREF(&sp);
 			IPSECLOG(LOG_DEBUG, "a SP entry exists already.\n");
 			return key_senderror(so, m, EEXIST);
 		}
@@ -1891,12 +1998,11 @@ key_api_spdadd(struct socket *so, struct
 	newsp->lifetime = lft ? lft->sadb_lifetime_addtime : 0;
 	newsp->validtime = lft ? lft->sadb_lifetime_usetime : 0;
 
-	newsp->refcnt = 1;	/* do not reclaim until I say I do */
-	newsp->state = IPSEC_SPSTATE_ALIVE;
-	if (newsp->policy == IPSEC_POLICY_IPSEC)
-		KASSERT(newsp->req != NULL);
-	SPLIST_ENTRY_INIT(newsp);
+	key_init_sp(newsp);
+
+	mutex_enter(&key_sp_mtx);
 	SPLIST_WRITER_INSERT_TAIL(newsp->spidx.dir, newsp);
+	mutex_exit(&key_sp_mtx);
 
 #ifdef notyet
 	/* delete the entry in spacqtree */
@@ -1984,7 +2090,7 @@ key_getnewspid(void)
 		if (sp == NULL)
 			break;
 
-		KEY_FREESP(&sp);
+		KEY_SP_UNREF(&sp);
 	}
 
 	if (count == 0 || newid == 0) {
@@ -2044,7 +2150,7 @@ key_api_spddelete(struct socket *so, str
 	key_init_spidx_bymsghdr(&spidx, mhp);
 
 	/* Is there SP in SPD ? */
-	sp = key_getsp(&spidx);
+	sp = key_lookup_and_remove_sp(&spidx);
 	if (sp == NULL) {
 		IPSECLOG(LOG_DEBUG, "no SP found.\n");
 		return key_senderror(so, m, EINVAL);
@@ -2053,12 +2159,7 @@ key_api_spddelete(struct socket *so, str
 	/* save policy id to buffer to be returned. */
 	xpl0->sadb_x_policy_id = sp->id;
 
-	key_sp_dead(sp);
-	key_sp_unlink(sp);	/* XXX jrs ordering */
-	KEY_FREESP(&sp);	/* ref gained by key_getspbyid */
-
-	/* Invalidate all cached SPD pointers in the PCBs. */
-	ipsec_invalpcbcacheall();
+	key_destroy_sp(sp);
 
 	/* We're deleting policy; no need to invalidate the ipflow cache. */
 
@@ -2109,19 +2210,13 @@ key_api_spddelete2(struct socket *so, st
 	id = ((struct sadb_x_policy *)mhp->ext[SADB_X_EXT_POLICY])->sadb_x_policy_id;
 
 	/* Is there SP in SPD ? */
-	sp = key_getspbyid(id);
+	sp = key_lookupbyid_and_remove_sp(id);
 	if (sp == NULL) {
 		IPSECLOG(LOG_DEBUG, "no SP found id:%u.\n", id);
 		return key_senderror(so, m, EINVAL);
 	}
 
-	key_sp_dead(sp);
-	key_sp_unlink(sp);	/* XXX jrs ordering */
-	KEY_FREESP(&sp);	/* ref gained by key_getsp */
-	sp = NULL;
-
-	/* Invalidate all cached SPD pointers in the PCBs. */
-	ipsec_invalpcbcacheall();
+	key_destroy_sp(sp);
 
 	/* We're deleting policy; no need to invalidate the ipflow cache. */
 
@@ -2211,7 +2306,7 @@ key_api_spdget(struct socket *so, struct
 
 	n = key_setdumpsp(sp, SADB_X_SPDGET, mhp->msg->sadb_msg_seq,
 	    mhp->msg->sadb_msg_pid);
-	KEY_FREESP(&sp); /* ref gained by key_getspbyid */
+	KEY_SP_UNREF(&sp); /* ref gained by key_getspbyid */
 	if (n != NULL) {
 		m_freem(m);
 		return key_sendup_mbuf(so, n, KEY_SENDUP_ONE);
@@ -2317,18 +2412,17 @@ key_api_spdflush(struct socket *so, stru
 
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
 	    retry:
+		mutex_enter(&key_sp_mtx);
 		SPLIST_WRITER_FOREACH(sp, dir) {
-			if (sp->state == IPSEC_SPSTATE_DEAD)
-				continue;
-			key_sp_dead(sp);
-			key_sp_unlink(sp);
+			KASSERT(sp->state != IPSEC_SPSTATE_DEAD);
+			key_unlink_sp(sp);
+			mutex_exit(&key_sp_mtx);
+			key_destroy_sp(sp);
 			goto retry;
 		}
+		mutex_exit(&key_sp_mtx);
 	}
 
-	/* Invalidate all cached SPD pointers in the PCBs. */
-	ipsec_invalpcbcacheall();
-
 	/* We're deleting policy; no need to invalidate the ipflow cache. */
 
 	if (sizeof(struct sadb_msg) > m->m_len + M_TRAILINGSPACE(m)) {
@@ -2361,12 +2455,14 @@ key_setspddump_chain(int *errorp, int *l
 	struct mbuf *m, *n, *prev;
 	int totlen;
 
+	KASSERT(mutex_owned(&key_sp_mtx));
+
 	*lenp = 0;
 
 	/* search SPD entry and get buffer size. */
 	cnt = 0;
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
-		SPLIST_READER_FOREACH(sp, dir) {
+		SPLIST_WRITER_FOREACH(sp, dir) {
 			cnt++;
 		}
 	}
@@ -2380,7 +2476,7 @@ key_setspddump_chain(int *errorp, int *l
 	prev = m;
 	totlen = 0;
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
-		SPLIST_READER_FOREACH(sp, dir) {
+		SPLIST_WRITER_FOREACH(sp, dir) {
 			--cnt;
 			n = key_setdumpsp(sp, SADB_X_SPDDUMP, cnt, pid);
 
@@ -2423,7 +2519,7 @@ key_api_spddump(struct socket *so, struc
 {
 	struct mbuf *n;
 	int error, len;
-	int ok, s;
+	int ok;
 	pid_t pid;
 
 	pid = mhp->msg->sadb_msg_pid;
@@ -2438,9 +2534,9 @@ key_api_spddump(struct socket *so, struc
 		return key_senderror(so, m0, ENOBUFS);
 	}
 
-	s = splsoftnet();
+	mutex_enter(&key_sp_mtx);
 	n = key_setspddump_chain(&error, &len, pid);
-	splx(s);
+	mutex_exit(&key_sp_mtx);
 
 	if (n == NULL) {
 		return key_senderror(so, m0, ENOENT);
@@ -4341,24 +4437,37 @@ key_timehandler_spd(time_t now)
 
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
 	    retry:
+		mutex_enter(&key_sp_mtx);
 		SPLIST_WRITER_FOREACH(sp, dir) {
-			if (sp->state == IPSEC_SPSTATE_DEAD) {
-				key_sp_unlink(sp);	/*XXX*/
-				goto retry;
-			}
+			KASSERT(sp->state != IPSEC_SPSTATE_DEAD);
 
 			if (sp->lifetime == 0 && sp->validtime == 0)
 				continue;
 
-			/* the deletion will occur next time */
 			if ((sp->lifetime && now - sp->created > sp->lifetime) ||
 			    (sp->validtime && now - sp->lastused > sp->validtime)) {
-			  	key_sp_dead(sp);
+				key_unlink_sp(sp);
+				mutex_exit(&key_sp_mtx);
 				key_spdexpire(sp);
+				key_destroy_sp(sp);
 				goto retry;
 			}
 		}
+		mutex_exit(&key_sp_mtx);
 	}
+
+    retry_socksplist:
+	mutex_enter(&key_sp_mtx);
+	SOCKSPLIST_WRITER_FOREACH(sp) {
+		if (sp->state != IPSEC_SPSTATE_DEAD)
+			continue;
+
+		key_unlink_sp(sp);
+		mutex_exit(&key_sp_mtx);
+		key_destroy_sp(sp);
+		goto retry_socksplist;
+	}
+	mutex_exit(&key_sp_mtx);
 }
 
 static void
@@ -7537,6 +7646,9 @@ key_do_init(void)
 	int i, error;
 
 	mutex_init(&key_mtx, MUTEX_DEFAULT, IPL_NONE);
+	key_psz = pserialize_create();
+	mutex_init(&key_sp_mtx, MUTEX_DEFAULT, IPL_NONE);
+	cv_init(&key_sp_cv, "key_sp");
 
 	pfkeystat_percpu = percpu_alloc(sizeof(uint64_t) * PFKEY_NSTATS);
 
@@ -7550,6 +7662,8 @@ key_do_init(void)
 		PSLIST_INIT(&sptree[i]);
 	}
 
+	PSLIST_INIT(&key_socksplist);
+
 	LIST_INIT(&sahtree);
 
 	for (i = 0; i <= SADB_SATYPE_MAX; i++) {
@@ -7565,11 +7679,13 @@ key_do_init(void)
 
 	/* system default */
 	ip4_def_policy.policy = IPSEC_POLICY_NONE;
-	ip4_def_policy.refcnt++;	/*never reclaim this*/
+	ip4_def_policy.state = IPSEC_SPSTATE_ALIVE;
+	localcount_init(&ip4_def_policy.localcount);
 
 #ifdef INET6
 	ip6_def_policy.policy = IPSEC_POLICY_NONE;
-	ip6_def_policy.refcnt++;	/*never reclaim this*/
+	ip6_def_policy.state = IPSEC_SPSTATE_ALIVE;
+	localcount_init(&ip6_def_policy.localcount);
 #endif
 
 	callout_reset(&key_timehandler_ch, hz, key_timehandler, NULL);
@@ -7909,10 +8025,12 @@ key_setspddump(int *errorp, pid_t pid)
 	u_int dir;
 	struct mbuf *m, *n;
 
+	KASSERT(mutex_owned(&key_sp_mtx));
+
 	/* search SPD entry and get buffer size. */
 	cnt = 0;
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
-		SPLIST_READER_FOREACH(sp, dir) {
+		SPLIST_WRITER_FOREACH(sp, dir) {
 			cnt++;
 		}
 	}
@@ -7924,7 +8042,7 @@ key_setspddump(int *errorp, pid_t pid)
 
 	m = NULL;
 	for (dir = 0; dir < IPSEC_DIR_MAX; dir++) {
-		SPLIST_READER_FOREACH(sp, dir) {
+		SPLIST_WRITER_FOREACH(sp, dir) {
 			--cnt;
 			n = key_setdumpsp(sp, SADB_X_SPDDUMP, cnt, pid);
 
@@ -8029,16 +8147,16 @@ sysctl_net_key_dumpsp(SYSCTLFN_ARGS)
 	int err2 = 0;
 	char *p, *ep;
 	size_t len;
-	int s, error;
+	int error;
 
 	if (newp)
 		return (EPERM);
 	if (namelen != 0)
 		return (EINVAL);
 
-	s = splsoftnet();
+	mutex_enter(&key_sp_mtx);
 	m = key_setspddump(&error, l->l_proc->p_pid);
-	splx(s);
+	mutex_exit(&key_sp_mtx);
 	if (!m)
 		return (error);
 	if (!oldp)

Index: src/sys/netipsec/key.h
diff -u src/sys/netipsec/key.h:1.25 src/sys/netipsec/key.h:1.26
--- src/sys/netipsec/key.h:1.25	Wed Jul 26 03:59:59 2017
+++ src/sys/netipsec/key.h	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: key.h,v 1.25 2017/07/26 03:59:59 ozaki-r Exp $	*/
+/*	$NetBSD: key.h,v 1.26 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/key.h,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$KAME: key.h,v 1.21 2001/07/27 03:51:30 itojun Exp $	*/
 
@@ -55,11 +55,15 @@ struct secpolicy *key_gettunnel(const st
 	const struct sockaddr *, const struct sockaddr *,
 	const struct sockaddr *, const char*, int);
 /* NB: prepend with _ for KAME IPv6 compatbility */
-void _key_freesp(struct secpolicy **, const char*, int);
+void key_init_sp(struct secpolicy *);
+void key_free_sp(struct secpolicy *);
 u_int key_sp_refcnt(const struct secpolicy *);
 void key_sp_ref(struct secpolicy *, const char*, int);
+void key_sp_unref(struct secpolicy *, const char*, int);
 void key_sa_ref(struct secasvar *, const char*, int);
 
+void key_socksplist_add(struct secpolicy *);
+
 /*
  * Access to the SADB are interlocked with splsoftnet.  In particular,
  * holders of SA's use this to block accesses by protocol processing
@@ -73,8 +77,8 @@ void key_sa_ref(struct secasvar *, const
 	key_newsp(__func__, __LINE__)
 #define	KEY_GETTUNNEL(osrc, odst, isrc, idst)			\
 	key_gettunnel(osrc, odst, isrc, idst, __func__, __LINE__)
-#define	KEY_FREESP(spp)						\
-	_key_freesp(spp, __func__, __LINE__)
+#define	KEY_SP_UNREF(spp)					\
+	key_sp_unref(*(spp), __func__, __LINE__)
 #define	KEY_SP_REF(sp)						\
 	key_sp_ref(sp, __func__, __LINE__)
 #define KEY_SA_REF(sav)						\

Index: src/sys/netipsec/xform_ah.c
diff -u src/sys/netipsec/xform_ah.c:1.69 src/sys/netipsec/xform_ah.c:1.70
--- src/sys/netipsec/xform_ah.c:1.69	Thu Jul 27 06:59:28 2017
+++ src/sys/netipsec/xform_ah.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_ah.c,v 1.69 2017/07/27 06:59:28 ozaki-r Exp $	*/
+/*	$NetBSD: xform_ah.c,v 1.70 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_ah.c,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$OpenBSD: ip_ah.c,v 1.63 2001/06/26 06:18:58 angelos Exp $ */
 /*
@@ -39,7 +39,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_ah.c,v 1.69 2017/07/27 06:59:28 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_ah.c,v 1.70 2017/08/02 01:28:03 ozaki-r Exp $");
 
 #if defined(_KERNEL_OPT)
 #include "opt_inet.h"
@@ -54,6 +54,7 @@ __KERNEL_RCSID(0, "$NetBSD: xform_ah.c,v
 #include <sys/kernel.h>
 #include <sys/sysctl.h>
 #include <sys/pool.h>
+#include <sys/pserialize.h>
 
 #include <net/if.h>
 
@@ -1140,6 +1141,21 @@ ah_output(
 		goto bad;
 	}
 
+    {
+	int s = pserialize_read_enter();
+
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+		pserialize_read_exit(s);
+		pool_put(&ah_tdb_crypto_pool, tc);
+		crypto_freereq(crp);
+		AH_STATINC(AH_STAT_NOTDB);
+		error = ENOENT;
+		goto bad;
+	}
+	KEY_SP_REF(isr->sp);
+	pserialize_read_exit(s);
+    }
+
 	/* Crypto operation descriptor. */
 	crp->crp_ilen = m->m_pkthdr.len; /* Total input length. */
 	crp->crp_flags = CRYPTO_F_IMBUF;
@@ -1150,7 +1166,6 @@ ah_output(
 
 	/* These are passed as-is to the callback. */
 	tc->tc_isr = isr;
-	KEY_SP_REF(isr->sp);
 	tc->tc_spi = sav->spi;
 	tc->tc_dst = sav->sah->saidx.dst;
 	tc->tc_proto = sav->sah->saidx.proto;
@@ -1255,13 +1270,13 @@ ah_output_cb(struct cryptop *crp)
 	/* NB: m is reclaimed by ipsec_process_done. */
 	err = ipsec_process_done(m, isr, sav);
 	KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	return err;
 bad:
 	if (sav)
 		KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	if (m)
 		m_freem(m);

Index: src/sys/netipsec/xform_esp.c
diff -u src/sys/netipsec/xform_esp.c:1.67 src/sys/netipsec/xform_esp.c:1.68
--- src/sys/netipsec/xform_esp.c:1.67	Thu Jul 27 06:59:28 2017
+++ src/sys/netipsec/xform_esp.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_esp.c,v 1.67 2017/07/27 06:59:28 ozaki-r Exp $	*/
+/*	$NetBSD: xform_esp.c,v 1.68 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_esp.c,v 1.2.2.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$OpenBSD: ip_esp.c,v 1.69 2001/06/26 06:18:59 angelos Exp $ */
 
@@ -39,7 +39,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_esp.c,v 1.67 2017/07/27 06:59:28 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_esp.c,v 1.68 2017/08/02 01:28:03 ozaki-r Exp $");
 
 #if defined(_KERNEL_OPT)
 #include "opt_inet.h"
@@ -55,6 +55,7 @@ __KERNEL_RCSID(0, "$NetBSD: xform_esp.c,
 #include <sys/sysctl.h>
 #include <sys/cprng.h>
 #include <sys/pool.h>
+#include <sys/pserialize.h>
 
 #include <net/if.h>
 
@@ -897,9 +898,23 @@ esp_output(
 		goto bad;
 	}
 
+    {
+	int s = pserialize_read_enter();
+
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+		pserialize_read_exit(s);
+		pool_put(&esp_tdb_crypto_pool, tc);
+		crypto_freereq(crp);
+		ESP_STATINC(ESP_STAT_NOTDB);
+		error = ENOENT;
+		goto bad;
+	}
+	KEY_SP_REF(isr->sp);
+	pserialize_read_exit(s);
+    }
+
 	/* Callback parameters */
 	tc->tc_isr = isr;
-	KEY_SP_REF(isr->sp);
 	tc->tc_spi = sav->spi;
 	tc->tc_dst = saidx->dst;
 	tc->tc_proto = saidx->proto;
@@ -1032,13 +1047,13 @@ esp_output_cb(struct cryptop *crp)
 	/* NB: m is reclaimed by ipsec_process_done. */
 	err = ipsec_process_done(m, isr, sav);
 	KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	return err;
 bad:
 	if (sav)
 		KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	if (m)
 		m_freem(m);

Index: src/sys/netipsec/xform_ipcomp.c
diff -u src/sys/netipsec/xform_ipcomp.c:1.48 src/sys/netipsec/xform_ipcomp.c:1.49
--- src/sys/netipsec/xform_ipcomp.c:1.48	Thu Jul 27 06:59:28 2017
+++ src/sys/netipsec/xform_ipcomp.c	Wed Aug  2 01:28:03 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_ipcomp.c,v 1.48 2017/07/27 06:59:28 ozaki-r Exp $	*/
+/*	$NetBSD: xform_ipcomp.c,v 1.49 2017/08/02 01:28:03 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_ipcomp.c,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /* $OpenBSD: ip_ipcomp.c,v 1.1 2001/07/05 12:08:52 jjbg Exp $ */
 
@@ -30,7 +30,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_ipcomp.c,v 1.48 2017/07/27 06:59:28 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_ipcomp.c,v 1.49 2017/08/02 01:28:03 ozaki-r Exp $");
 
 /* IP payload compression protocol (IPComp), see RFC 2393 */
 #if defined(_KERNEL_OPT)
@@ -45,6 +45,7 @@ __KERNEL_RCSID(0, "$NetBSD: xform_ipcomp
 #include <sys/protosw.h>
 #include <sys/sysctl.h>
 #include <sys/pool.h>
+#include <sys/pserialize.h>
 
 #include <netinet/in.h>
 #include <netinet/in_systm.h>
@@ -479,8 +480,22 @@ ipcomp_output(
 		goto bad;
 	}
 
-	tc->tc_isr = isr;
+    {
+	int s = pserialize_read_enter();
+
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+		pserialize_read_exit(s);
+		pool_put(&ipcomp_tdb_crypto_pool, tc);
+		crypto_freereq(crp);
+		IPCOMP_STATINC(IPCOMP_STAT_NOTDB);
+		error = ENOENT;
+		goto bad;
+	}
 	KEY_SP_REF(isr->sp);
+	pserialize_read_exit(s);
+    }
+
+	tc->tc_isr = isr;
 	tc->tc_spi = sav->spi;
 	tc->tc_dst = sav->sah->saidx.dst;
 	tc->tc_proto = sav->sah->saidx.proto;
@@ -646,13 +661,13 @@ ipcomp_output_cb(struct cryptop *crp)
 	/* NB: m is reclaimed by ipsec_process_done. */
 	error = ipsec_process_done(m, isr, sav);
 	KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	return error;
 bad:
 	if (sav)
 		KEY_FREESAV(&sav);
-	KEY_FREESP(&isr->sp);
+	KEY_SP_UNREF(&isr->sp);
 	IPSEC_RELEASE_GLOBAL_LOCKS();
 	if (m)
 		m_freem(m);

Index: src/sys/rump/librump/rumpnet/net_stub.c
diff -u src/sys/rump/librump/rumpnet/net_stub.c:1.26 src/sys/rump/librump/rumpnet/net_stub.c:1.27
--- src/sys/rump/librump/rumpnet/net_stub.c:1.26	Fri Apr 14 02:43:28 2017
+++ src/sys/rump/librump/rumpnet/net_stub.c	Wed Aug  2 01:28:02 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: net_stub.c,v 1.26 2017/04/14 02:43:28 ozaki-r Exp $	*/
+/*	$NetBSD: net_stub.c,v 1.27 2017/08/02 01:28:02 ozaki-r Exp $	*/
 
 /*
  * Copyright (c) 2008 Antti Kantee.  All Rights Reserved.
@@ -26,7 +26,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: net_stub.c,v 1.26 2017/04/14 02:43:28 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: net_stub.c,v 1.27 2017/08/02 01:28:02 ozaki-r Exp $");
 
 #include <sys/mutex.h>
 #include <sys/param.h>
@@ -108,7 +108,7 @@ __weak_alias(ipsec_init_policy,rumpnet_s
 __weak_alias(ipsec_pcbconn,rumpnet_stub);
 __weak_alias(ipsec_pcbdisconn,rumpnet_stub);
 __weak_alias(key_sa_routechange,rumpnet_stub);
-__weak_alias(_key_freesp,rumpnet_stub);
+__weak_alias(key_sp_unref,rumpnet_stub);
 
 struct ifnet_head ifnet_list;
 struct pslist_head ifnet_pslist;

Reply via email to