Module Name:    src
Committed By:   ozaki-r
Date:           Wed Aug  9 09:48:11 UTC 2017

Modified Files:
        src/sys/netipsec: key.c key.h keydb.h xform_ah.c xform_esp.c
            xform_ipcomp.c

Log Message:
MP-ify SAD (savlist)

localcount(9) is used to protect the savlist of each sah. The basic design is
similar to the MP-ification of the SPD and of the SAD sahlist. Please read the
locking notes of the SAD for more details.


To generate a diff of this commit:
cvs rdiff -u -r1.222 -r1.223 src/sys/netipsec/key.c
cvs rdiff -u -r1.28 -r1.29 src/sys/netipsec/key.h
cvs rdiff -u -r1.19 -r1.20 src/sys/netipsec/keydb.h
cvs rdiff -u -r1.71 -r1.72 src/sys/netipsec/xform_ah.c
cvs rdiff -u -r1.69 -r1.70 src/sys/netipsec/xform_esp.c
cvs rdiff -u -r1.50 -r1.51 src/sys/netipsec/xform_ipcomp.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/sys/netipsec/key.c
diff -u src/sys/netipsec/key.c:1.222 src/sys/netipsec/key.c:1.223
--- src/sys/netipsec/key.c:1.222	Wed Aug  9 08:30:54 2017
+++ src/sys/netipsec/key.c	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: key.c,v 1.222 2017/08/09 08:30:54 ozaki-r Exp $	*/
+/*	$NetBSD: key.c,v 1.223 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/key.c,v 1.3.2.3 2004/02/14 22:23:23 bms Exp $	*/
 /*	$KAME: key.c,v 1.191 2001/06/27 10:46:49 sakane Exp $	*/
 
@@ -32,7 +32,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: key.c,v 1.222 2017/08/09 08:30:54 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: key.c,v 1.223 2017/08/09 09:48:11 ozaki-r Exp $");
 
 /*
  * This code is referd to RFC 2367
@@ -202,13 +202,15 @@ static u_int32_t acq_seq = 0;
  * - Data structures
  *   - SAs are managed by the list called key_sad.sahlist and sav lists of sah
  *     entries
+ *     - An sav represents an SA as seen by users
  *   - A sah has sav lists for each SA state
  *   - Multiple sahs with the same saidx can exist
  *     - Only one entry has MATURE state and others should be DEAD
  *     - DEAD entries are just ignored from searching
- * - Modifications to the key_sad.sahlist must be done with holding key_sad.lock
- *   which is a adaptive mutex
- * - Read accesses to the key_sad.sahlist must be in pserialize(9) read sections
+ * - Modifications to the key_sad.sahlist and sah.savlist must be done while
+ *   holding key_sad.lock, which is an adaptive mutex
+ * - Read accesses to the key_sad.sahlist and sah.savlist must be in
+ *   pserialize(9) read sections
  * - sah's lifetime is managed by localcount(9)
  * - Getting an sah entry
  *   - We get an sah from the key_sad.sahlist
@@ -218,6 +220,16 @@ static u_int32_t acq_seq = 0;
  * - An sah is destroyed when its state become DEAD and no sav is
  *   listed to the sah
  *   - The destruction is done only in the timer (see key_timehandler_sad)
+ * - sav's lifetime is managed by localcount(9)
+ * - Getting an sav entry
+ *   - First get an sah by saidx, then an sav from one of the sah's savlists
+ *     - The list must be iterated, and the reference count of the found sav
+ *       incremented (by key_sa_ref), within a pserialize read section
+ *   - Another reference may be taken from a held SA only if its state
+ *     is checked and the reference is taken within a pserialize read section
+ *     (see esp_output for example)
+ *   - An acquired sav must be released by key_sa_unref after use
+ * - An sav is destroyed when its state becomes DEAD
  */
 /*
  * Locking notes on misc data:
@@ -643,6 +655,9 @@ static void key_destroy_sah(struct secas
 static bool key_sah_has_sav(struct secashead *);
 static void key_sah_ref(struct secashead *);
 static void key_sah_unref(struct secashead *);
+static void key_init_sav(struct secasvar *);
+static void key_destroy_sav(struct secasvar *);
+static void key_destroy_sav_with_ref(struct secasvar *);
 static struct secasvar *key_newsav(struct mbuf *,
 	const struct sadb_msghdr *, int *, const char*, int);
 #define	KEY_NEWSAV(m, sadb, e)				\
@@ -776,35 +791,6 @@ static struct callout	key_timehandler_ch
 static struct workqueue	*key_timehandler_wq;
 static struct work	key_timehandler_wk;
 
-#ifdef IPSEC_REF_DEBUG
-#define REFLOG(label, p, where, tag)					\
-	log(LOG_DEBUG, "%s:%d: " label " : refcnt=%d (%p)\n.",		\
-	    (where), (tag), (p)->refcnt, (p))
-#else
-#define REFLOG(label, p, where, tag)	do {} while (0)
-#endif
-
-#define	SA_ADDREF(p) do {						\
-	atomic_inc_uint(&(p)->refcnt);					\
-	REFLOG("SA_ADDREF", (p), __func__, __LINE__);			\
-	KASSERTMSG((p)->refcnt != 0, "SA refcnt overflow");		\
-} while (0)
-#define	SA_ADDREF2(p, where, tag) do {					\
-	atomic_inc_uint(&(p)->refcnt);					\
-	REFLOG("SA_ADDREF", (p), (where), (tag));			\
-	KASSERTMSG((p)->refcnt != 0, "SA refcnt overflow");		\
-} while (0)
-#define	SA_DELREF(p) do {						\
-	KASSERTMSG((p)->refcnt > 0, "SA refcnt underflow");		\
-	atomic_dec_uint(&(p)->refcnt);					\
-	REFLOG("SA_DELREF", (p), __func__, __LINE__);			\
-} while (0)
-#define	SA_DELREF2(p, nv, where, tag) do {				\
-	KASSERTMSG((p)->refcnt > 0, "SA refcnt underflow");		\
-	nv = atomic_dec_uint_nv(&(p)->refcnt);				\
-	REFLOG("SA_DELREF", (p), (where), (tag));			\
-} while (0)
-
 u_int
 key_sp_refcnt(const struct secpolicy *sp)
 {
@@ -1085,7 +1071,7 @@ key_lookup_sa_bysaidx(const struct secas
 			sav = last;
 		}
 		if (sav != NULL) {
-			SA_ADDREF(sav);
+			KEY_SA_REF(sav);
 			KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
 			    "DP cause refcnt++:%d SA:%p\n",
 			    key_sa_refcnt(sav), sav);
@@ -1284,7 +1270,7 @@ key_lookup_sa(
 				/* check dst address */
 				if (!key_sockaddr_match(&dst->sa, &sav->sah->saidx.dst.sa, chkport))
 					continue;
-				SA_ADDREF2(sav, where, tag);
+				key_sa_ref(sav, where, tag);
 				goto done;
 			}
 		}
@@ -1371,25 +1357,46 @@ key_sp_unref(struct secpolicy *sp, const
 	localcount_release(&sp->localcount, &key_spd.cv, &key_spd.lock);
 }
 
+static void
+key_init_sav(struct secasvar *sav)
+{
+
+	ASSERT_SLEEPABLE();
+
+	localcount_init(&sav->localcount);
+	SAVLIST_ENTRY_INIT(sav);
+}
+
 u_int
 key_sa_refcnt(const struct secasvar *sav)
 {
 
-	if (sav == NULL)
-		return 0;
-
-	return sav->refcnt;
+	/* FIXME: no refcnt field anymore (localcount(9) is used); always 0 */
+	return 0;
 }
 
 void
 key_sa_ref(struct secasvar *sav, const char* where, int tag)
 {
 
-	SA_ADDREF2(sav, where, tag);
+	localcount_acquire(&sav->localcount);
 
 	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
-	    "DP cause refcnt++:%d SA:%p from %s:%u\n",
-	    sav->refcnt, sav, where, tag);
+	    "DP cause refcnt++: SA:%p from %s:%u\n",
+	    sav, where, tag);
+}
+
+void
+key_sa_unref(struct secasvar *sav, const char* where, int tag)
+{
+
+	KDASSERT(mutex_ownable(&key_sad.lock));
+
+	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
+	    "DP cause refcnt--: SA:%p from %s:%u\n",
+	    sav, where, tag);
+
+	localcount_release(&sav->localcount, &key_sad.cv, &key_sad.lock);
 }
 
 #if 0
@@ -1468,32 +1475,67 @@ key_freesp_so(struct secpolicy **sp)
 #endif
 
 /*
- * Must be called after calling key_lookup_sa().
- * This function is called by key_freesp() to free some SA allocated
- * for a policy.
+ * Remove the sav from the savlist of its sah and wait for references to the sav
+ * to be released. key_sad.lock must be held.
  */
-void
-key_freesav(struct secasvar **psav, const char* where, int tag)
+static void
+key_unlink_sav(struct secasvar *sav)
 {
-	struct secasvar *sav = *psav;
-	unsigned int nv;
 
-	KASSERT(sav != NULL);
+	KASSERT(mutex_owned(&key_sad.lock));
 
-	SA_DELREF2(sav, nv, where, tag);
+	SAVLIST_WRITER_REMOVE(sav);
 
-	KEYDEBUG_PRINTF(KEYDEBUG_IPSEC_STAMP,
-	    "DP SA:%p (SPI %lu) from %s:%u; refcnt now %u\n",
-	    sav, (u_long)ntohl(sav->spi), where, tag, nv);
+#ifdef NET_MPSAFE
+	KASSERT(mutex_ownable(softnet_lock));
+	pserialize_perform(key_sad_psz);
+#endif
 
-	if (nv == 0) {
-		*psav = NULL;
+	localcount_drain(&sav->localcount, &key_sad.cv, &key_sad.lock);
+}
 
-		/* remove from SA header */
-		SAVLIST_WRITER_REMOVE(sav);
+/*
+ * Destroy an sav. The sav must already be unlinked from its sah,
+ * e.g., by key_unlink_sav.
+ */
+static void
+key_destroy_sav(struct secasvar *sav)
+{
 
-		key_delsav(sav);
-	}
+	ASSERT_SLEEPABLE();
+
+	localcount_fini(&sav->localcount);
+	SAVLIST_ENTRY_DESTROY(sav);
+
+	key_delsav(sav);
+}
+
+/*
+ * Destroy an sav while the caller holds a reference to it.
+ */
+static void
+key_destroy_sav_with_ref(struct secasvar *sav)
+{
+
+	ASSERT_SLEEPABLE();
+
+	mutex_enter(&key_sad.lock);
+	sav->state = SADB_SASTATE_DEAD;
+	SAVLIST_WRITER_REMOVE(sav);
+	mutex_exit(&key_sad.lock);
+
+	/* key_sa_unref must not be called with key_sad.lock held */
+	KEY_SA_UNREF(&sav);
+
+	mutex_enter(&key_sad.lock);
+#ifdef NET_MPSAFE
+	KASSERT(mutex_ownable(softnet_lock));
+	pserialize_perform(key_sad_psz);
+#endif
+	localcount_drain(&sav->localcount, &key_sad.cv, &key_sad.lock);
+	mutex_exit(&key_sad.lock);
+
+	key_destroy_sav(sav);
 }
 
 /* %%% SPD management */
@@ -3173,15 +3215,9 @@ static void
 key_delsav(struct secasvar *sav)
 {
 
-	KASSERT(sav != NULL);
-	KASSERTMSG(sav->refcnt == 0, "reference count %u > 0", sav->refcnt);
-
 	key_clear_xform(sav);
 	key_freesaval(sav);
-	SAVLIST_ENTRY_DESTROY(sav);
-	kmem_intr_free(sav, sizeof(*sav));
-
-	return;
+	kmem_free(sav, sizeof(*sav));
 }
 
 /*
@@ -3315,7 +3351,7 @@ key_getsavbyspi(struct secashead *sah, u
 			}
 
 			if (sav->spi == spi) {
-				SA_ADDREF(sav);
+				KEY_SA_REF(sav);
 				goto out;
 			}
 		}
@@ -4718,20 +4754,23 @@ restart:
 		pserialize_read_exit(s);
 
 		/* if LARVAL entry doesn't become MATURE, delete it. */
+		mutex_enter(&key_sad.lock);
 	restart_sav_LARVAL:
-		SAVLIST_READER_FOREACH(sav, sah, SADB_SASTATE_LARVAL) {
+		SAVLIST_WRITER_FOREACH(sav, sah, SADB_SASTATE_LARVAL) {
 			if (now - sav->created > key_larval_lifetime) {
 				key_sa_chgstate(sav, SADB_SASTATE_DEAD);
 				goto restart_sav_LARVAL;
 			}
 		}
+		mutex_exit(&key_sad.lock);
 
 		/*
 		 * check MATURE entry to start to send expire message
 		 * whether or not.
 		 */
 	restart_sav_MATURE:
-		SAVLIST_READER_FOREACH(sav, sah, SADB_SASTATE_MATURE) {
+		mutex_enter(&key_sad.lock);
+		SAVLIST_WRITER_FOREACH(sav, sah, SADB_SASTATE_MATURE) {
 			/* we don't need to check. */
 			if (sav->lft_s == NULL)
 				continue;
@@ -4748,8 +4787,10 @@ restart:
 				 */
 				if (sav->lft_c->sadb_lifetime_usetime == 0) {
 					key_sa_chgstate(sav, SADB_SASTATE_DEAD);
+					mutex_exit(&key_sad.lock);
 				} else {
 					key_sa_chgstate(sav, SADB_SASTATE_DYING);
+					mutex_exit(&key_sad.lock);
 					/*
 					 * XXX If we keep to send expire
 					 * message in the status of
@@ -4770,6 +4811,7 @@ restart:
 			         sav->lft_c->sadb_lifetime_bytes) {
 
 				key_sa_chgstate(sav, SADB_SASTATE_DYING);
+				mutex_exit(&key_sad.lock);
 				/*
 				 * XXX If we keep to send expire
 				 * message in the status of
@@ -4779,10 +4821,12 @@ restart:
 				goto restart_sav_MATURE;
 			}
 		}
+		mutex_exit(&key_sad.lock);
 
 		/* check DYING entry to change status to DEAD. */
+		mutex_enter(&key_sad.lock);
 	restart_sav_DYING:
-		SAVLIST_READER_FOREACH(sav, sah, SADB_SASTATE_DYING) {
+		SAVLIST_WRITER_FOREACH(sav, sah, SADB_SASTATE_DYING) {
 			/* we don't need to check. */
 			if (sav->lft_h == NULL)
 				continue;
@@ -4819,13 +4863,18 @@ restart:
 				goto restart_sav_DYING;
 			}
 		}
+		mutex_exit(&key_sad.lock);
 
 		/* delete entry in DEAD */
 	restart_sav_DEAD:
-		SAVLIST_READER_FOREACH(sav, sah, SADB_SASTATE_DEAD) {
-			KEY_FREESAV(&sav);
+		mutex_enter(&key_sad.lock);
+		SAVLIST_WRITER_FOREACH(sav, sah, SADB_SASTATE_DEAD) {
+			key_unlink_sav(sav);
+			mutex_exit(&key_sad.lock);
+			key_destroy_sav(sav);
 			goto restart_sav_DEAD;
 		}
+		mutex_exit(&key_sad.lock);
 
 		s = pserialize_read_enter();
 		key_sah_unref(sah);
@@ -5120,11 +5169,10 @@ key_api_getspi(struct socket *so, struct
 	/* set spi */
 	newsav->spi = htonl(spi);
 
-	/* add to satree */
-	newsav->refcnt = 1;
+	/* Add to the savlist of the sah */
+	key_init_sav(newsav);
 	newsav->sah = sah;
 	newsav->state = SADB_SASTATE_LARVAL;
-	SAVLIST_ENTRY_INIT(newsav);
 	mutex_enter(&key_sad.lock);
 	SAVLIST_WRITER_INSERT_TAIL(sah, SADB_SASTATE_LARVAL, newsav);
 	mutex_exit(&key_sad.lock);
@@ -5567,22 +5615,20 @@ key_api_update(struct socket *so, struct
 		goto error;
 	}
 
-	/* add to satree */
-	newsav->refcnt = 1;
+	/* Add to the savlist of the sah */
+	key_init_sav(newsav);
 	newsav->state = SADB_SASTATE_MATURE;
-	SAVLIST_ENTRY_INIT(newsav);
 	mutex_enter(&key_sad.lock);
 	SAVLIST_WRITER_INSERT_TAIL(sah, SADB_SASTATE_MATURE, newsav);
 	mutex_exit(&key_sad.lock);
 	key_validate_savlist(sah, SADB_SASTATE_MATURE);
 
-	key_sa_chgstate(sav, SADB_SASTATE_DEAD);
-	KEY_FREESAV(&sav);
-	KEY_FREESAV(&sav);
-
 	key_sah_unref(sah);
 	sah = NULL;
 
+	key_destroy_sav_with_ref(sav);
+	sav = NULL;
+
     {
 	struct mbuf *n;
 
@@ -5765,10 +5811,9 @@ key_api_add(struct socket *so, struct mb
 		goto error;
 	}
 
-	/* add to satree */
-	newsav->refcnt = 1;
+	/* Add to the savlist of the sah */
+	key_init_sav(newsav);
 	newsav->state = SADB_SASTATE_MATURE;
-	SAVLIST_ENTRY_INIT(newsav);
 	mutex_enter(&key_sad.lock);
 	SAVLIST_WRITER_INSERT_TAIL(sah, SADB_SASTATE_MATURE, newsav);
 	mutex_exit(&key_sad.lock);
@@ -5995,9 +6040,8 @@ key_api_delete(struct socket *so, struct
 		return key_senderror(so, m, ENOENT);
 	}
 
-	key_sa_chgstate(sav, SADB_SASTATE_DEAD);
-	KEY_FREESAV(&sav);
-	KEY_FREESAV(&sav);
+	key_destroy_sav_with_ref(sav);
+	sav = NULL;
 
     {
 	struct mbuf *n;
@@ -6049,20 +6093,15 @@ key_delete_all(struct socket *so, struct
 			if (state == SADB_SASTATE_LARVAL)
 				continue;
 		restart:
+			mutex_enter(&key_sad.lock);
 			SAVLIST_WRITER_FOREACH(sav, sah, state) {
-				/* sanity check */
-				if (sav->state != state) {
-					IPSECLOG(LOG_DEBUG,
-					    "invalid sav->state "
-					    "(queue: %d SA: %d)\n",
-					    state, sav->state);
-					continue;
-				}
-
-				key_sa_chgstate(sav, SADB_SASTATE_DEAD);
-				KEY_FREESAV(&sav);
+				sav->state = SADB_SASTATE_DEAD;
+				key_unlink_sav(sav);
+				mutex_exit(&key_sad.lock);
+				key_destroy_sav(sav);
 				goto restart;
 			}
+			mutex_exit(&key_sad.lock);
 		}
 		key_sah_unref(sah);
 	}
@@ -7292,11 +7331,15 @@ key_api_flush(struct socket *so, struct 
 
 		SASTATE_ALIVE_FOREACH(state) {
 		restart:
+			mutex_enter(&key_sad.lock);
 			SAVLIST_WRITER_FOREACH(sav, sah, state) {
-				key_sa_chgstate(sav, SADB_SASTATE_DEAD);
-				KEY_FREESAV(&sav);
+				sav->state = SADB_SASTATE_DEAD;
+				key_unlink_sav(sav);
+				mutex_exit(&key_sad.lock);
+				key_destroy_sav(sav);
 				goto restart;
 			}
+			mutex_exit(&key_sad.lock);
 		}
 
 		s = pserialize_read_enter();
@@ -8241,14 +8284,17 @@ key_sa_chgstate(struct secasvar *sav, u_
 {
 	struct secasvar *_sav;
 
-	KASSERT(sav != NULL);
+	ASSERT_SLEEPABLE();
+	KASSERT(mutex_owned(&key_sad.lock));
 
 	if (sav->state == state)
 		return;
 
-	SAVLIST_WRITER_REMOVE(sav);
+	key_unlink_sav(sav);
+	/* key_unlink_sav drained the localcount; finalize it before re-init */
+	localcount_fini(&sav->localcount);
 	SAVLIST_ENTRY_DESTROY(sav);
-	SAVLIST_ENTRY_INIT(sav);
+	key_init_sav(sav);
 
 	sav->state = state;
 	if (!SADB_SASTATE_USABLE_P(sav)) {

Index: src/sys/netipsec/key.h
diff -u src/sys/netipsec/key.h:1.28 src/sys/netipsec/key.h:1.29
--- src/sys/netipsec/key.h:1.28	Tue Aug  8 08:23:10 2017
+++ src/sys/netipsec/key.h	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: key.h,v 1.28 2017/08/08 08:23:10 ozaki-r Exp $	*/
+/*	$NetBSD: key.h,v 1.29 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/key.h,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$KAME: key.h,v 1.21 2001/07/27 03:51:30 itojun Exp $	*/
 
@@ -61,6 +61,7 @@ u_int key_sp_refcnt(const struct secpoli
 void key_sp_ref(struct secpolicy *, const char*, int);
 void key_sp_unref(struct secpolicy *, const char*, int);
 void key_sa_ref(struct secasvar *, const char*, int);
+void key_sa_unref(struct secasvar *, const char*, int);
 u_int key_sa_refcnt(const struct secasvar *);
 
 void key_socksplist_add(struct secpolicy *);
@@ -85,7 +86,7 @@ void key_socksplist_add(struct secpolicy
 #define KEY_SA_REF(sav)						\
 	key_sa_ref(sav, __func__, __LINE__)
 #define	KEY_SA_UNREF(psav)					\
-	key_freesav(psav, __func__, __LINE__)
+	key_sa_unref(*(psav), __func__, __LINE__)
 
 struct secasvar *key_lookup_sa(const union sockaddr_union *,
 		u_int, u_int32_t, u_int16_t, u_int16_t, const char*, int);
@@ -93,8 +94,6 @@ void key_freesav(struct secasvar **, con
 
 #define	KEY_LOOKUP_SA(dst, proto, spi, sport, dport)		\
 	key_lookup_sa(dst, proto, spi, sport, dport,  __func__, __LINE__)
-#define	KEY_FREESAV(psav)					\
-	key_freesav(psav, __func__, __LINE__)
 
 int key_checktunnelsanity (struct secasvar *, u_int, void *, void *);
 int key_checkrequest(struct ipsecrequest *, struct secasvar **);

Index: src/sys/netipsec/keydb.h
diff -u src/sys/netipsec/keydb.h:1.19 src/sys/netipsec/keydb.h:1.20
--- src/sys/netipsec/keydb.h:1.19	Tue Aug  8 04:17:34 2017
+++ src/sys/netipsec/keydb.h	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: keydb.h,v 1.19 2017/08/08 04:17:34 ozaki-r Exp $	*/
+/*	$NetBSD: keydb.h,v 1.20 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/keydb.h,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$KAME: keydb.h,v 1.14 2000/08/02 17:58:26 sakane Exp $	*/
 
@@ -94,8 +94,8 @@ struct comp_algo;
 /* Security Association */
 struct secasvar {
 	struct pslist_entry pslist_entry;
+	struct localcount localcount;	/* reference count */
 
-	u_int refcnt;			/* reference count */
 	u_int8_t state;			/* Status of this Association */
 
 	u_int8_t alg_auth;		/* Authentication Algorithm Identifier*/

Index: src/sys/netipsec/xform_ah.c
diff -u src/sys/netipsec/xform_ah.c:1.71 src/sys/netipsec/xform_ah.c:1.72
--- src/sys/netipsec/xform_ah.c:1.71	Thu Aug  3 06:32:51 2017
+++ src/sys/netipsec/xform_ah.c	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_ah.c,v 1.71 2017/08/03 06:32:51 ozaki-r Exp $	*/
+/*	$NetBSD: xform_ah.c,v 1.72 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_ah.c,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$OpenBSD: ip_ah.c,v 1.63 2001/06/26 06:18:58 angelos Exp $ */
 /*
@@ -39,7 +39,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_ah.c,v 1.71 2017/08/03 06:32:51 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_ah.c,v 1.72 2017/08/09 09:48:11 ozaki-r Exp $");
 
 #if defined(_KERNEL_OPT)
 #include "opt_inet.h"
@@ -733,6 +733,22 @@ ah_input(struct mbuf *m, struct secasvar
 		goto bad;
 	}
 
+    {
+	int s = pserialize_read_enter();
+
+	/*
+	 * Take another reference to the SA for the opencrypto callback.
+	 */
+	if (__predict_false(sav->state == SADB_SASTATE_DEAD)) {
+		pserialize_read_exit(s);
+		stat = AH_STAT_NOTDB;
+		error = ENOENT;
+		goto bad;
+	}
+	KEY_SA_REF(sav);
+	pserialize_read_exit(s);
+    }
+
 	/* Crypto operation descriptor. */
 	crp->crp_ilen = m->m_pkthdr.len; /* Total input length. */
 	crp->crp_flags = CRYPTO_F_IMBUF;
@@ -749,7 +765,6 @@ ah_input(struct mbuf *m, struct secasvar
 	tc->tc_protoff = protoff;
 	tc->tc_skip = skip;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	DPRINTF(("%s: hash over %d bytes, skip %d: "
 		 "crda len %d skip %d inject %d\n", __func__,
@@ -1144,7 +1159,11 @@ ah_output(
     {
 	int s = pserialize_read_enter();
 
-	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+	/*
+	 * Take another reference to the SP and the SA for opencrypto callback.
+	 */
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD ||
+	    sav->state == SADB_SASTATE_DEAD)) {
 		pserialize_read_exit(s);
 		pool_put(&ah_tdb_crypto_pool, tc);
 		crypto_freereq(crp);
@@ -1153,6 +1172,7 @@ ah_output(
 		goto bad;
 	}
 	KEY_SP_REF(isr->sp);
+	KEY_SA_REF(sav);
 	pserialize_read_exit(s);
     }
 
@@ -1172,7 +1192,6 @@ ah_output(
 	tc->tc_skip = skip;
 	tc->tc_protoff = protoff;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	return crypto_dispatch(crp);
 bad:

Index: src/sys/netipsec/xform_esp.c
diff -u src/sys/netipsec/xform_esp.c:1.69 src/sys/netipsec/xform_esp.c:1.70
--- src/sys/netipsec/xform_esp.c:1.69	Thu Aug  3 06:32:51 2017
+++ src/sys/netipsec/xform_esp.c	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_esp.c,v 1.69 2017/08/03 06:32:51 ozaki-r Exp $	*/
+/*	$NetBSD: xform_esp.c,v 1.70 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_esp.c,v 1.2.2.1 2003/01/24 05:11:36 sam Exp $	*/
 /*	$OpenBSD: ip_esp.c,v 1.69 2001/06/26 06:18:59 angelos Exp $ */
 
@@ -39,7 +39,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_esp.c,v 1.69 2017/08/03 06:32:51 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_esp.c,v 1.70 2017/08/09 09:48:11 ozaki-r Exp $");
 
 #if defined(_KERNEL_OPT)
 #include "opt_inet.h"
@@ -429,6 +429,24 @@ esp_input(struct mbuf *m, struct secasva
 		crde = crp->crp_desc;
 	}
 
+    {
+	int s = pserialize_read_enter();
+
+	/*
+	 * Take another reference to the SA for opencrypto callback.
+	 */
+	if (__predict_false(sav->state == SADB_SASTATE_DEAD)) {
+		pserialize_read_exit(s);
+		pool_put(&esp_tdb_crypto_pool, tc);
+		crypto_freereq(crp);
+		m_freem(m);
+		ESP_STATINC(ESP_STAT_NOTDB);
+		return ENOENT;
+	}
+	KEY_SA_REF(sav);
+	pserialize_read_exit(s);
+    }
+
 	/* Crypto operation descriptor */
 	crp->crp_ilen = m->m_pkthdr.len; /* Total input length */
 	crp->crp_flags = CRYPTO_F_IMBUF;
@@ -444,7 +462,6 @@ esp_input(struct mbuf *m, struct secasva
 	tc->tc_protoff = protoff;
 	tc->tc_skip = skip;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	/* Decryption descriptor */
 	if (espx) {
@@ -901,7 +918,11 @@ esp_output(
     {
 	int s = pserialize_read_enter();
 
-	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+	/*
+	 * Take another reference to the SP and the SA for opencrypto callback.
+	 */
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD ||
+	    sav->state == SADB_SASTATE_DEAD)) {
 		pserialize_read_exit(s);
 		pool_put(&esp_tdb_crypto_pool, tc);
 		crypto_freereq(crp);
@@ -910,6 +931,7 @@ esp_output(
 		goto bad;
 	}
 	KEY_SP_REF(isr->sp);
+	KEY_SA_REF(sav);
 	pserialize_read_exit(s);
     }
 
@@ -919,7 +941,6 @@ esp_output(
 	tc->tc_dst = saidx->dst;
 	tc->tc_proto = saidx->proto;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	/* Crypto operation descriptor. */
 	crp->crp_ilen = m->m_pkthdr.len; /* Total input length. */

Index: src/sys/netipsec/xform_ipcomp.c
diff -u src/sys/netipsec/xform_ipcomp.c:1.50 src/sys/netipsec/xform_ipcomp.c:1.51
--- src/sys/netipsec/xform_ipcomp.c:1.50	Thu Aug  3 06:32:51 2017
+++ src/sys/netipsec/xform_ipcomp.c	Wed Aug  9 09:48:11 2017
@@ -1,4 +1,4 @@
-/*	$NetBSD: xform_ipcomp.c,v 1.50 2017/08/03 06:32:51 ozaki-r Exp $	*/
+/*	$NetBSD: xform_ipcomp.c,v 1.51 2017/08/09 09:48:11 ozaki-r Exp $	*/
 /*	$FreeBSD: src/sys/netipsec/xform_ipcomp.c,v 1.1.4.1 2003/01/24 05:11:36 sam Exp $	*/
 /* $OpenBSD: ip_ipcomp.c,v 1.1 2001/07/05 12:08:52 jjbg Exp $ */
 
@@ -30,7 +30,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: xform_ipcomp.c,v 1.50 2017/08/03 06:32:51 ozaki-r Exp $");
+__KERNEL_RCSID(0, "$NetBSD: xform_ipcomp.c,v 1.51 2017/08/09 09:48:11 ozaki-r Exp $");
 
 /* IP payload compression protocol (IPComp), see RFC 2393 */
 #if defined(_KERNEL_OPT)
@@ -184,6 +184,24 @@ ipcomp_input(struct mbuf *m, struct seca
 		return error;
 	}
 
+    {
+	int s = pserialize_read_enter();
+
+	/*
+	 * Take another reference to the SA for opencrypto callback.
+	 */
+	if (__predict_false(sav->state == SADB_SASTATE_DEAD)) {
+		pserialize_read_exit(s);
+		pool_put(&ipcomp_tdb_crypto_pool, tc);
+		crypto_freereq(crp);
+		m_freem(m);
+		IPCOMP_STATINC(IPCOMP_STAT_NOTDB);
+		return ENOENT;
+	}
+	KEY_SA_REF(sav);
+	pserialize_read_exit(s);
+    }
+
 	crdc = crp->crp_desc;
 
 	crdc->crd_skip = skip + hlen;
@@ -209,7 +227,6 @@ ipcomp_input(struct mbuf *m, struct seca
 	tc->tc_protoff = protoff;
 	tc->tc_skip = skip;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	return crypto_dispatch(crp);
 }
@@ -483,7 +500,11 @@ ipcomp_output(
     {
 	int s = pserialize_read_enter();
 
-	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD)) {
+	/*
+	 * Take another reference to the SP and the SA for opencrypto callback.
+	 */
+	if (__predict_false(isr->sp->state == IPSEC_SPSTATE_DEAD ||
+	    sav->state == SADB_SASTATE_DEAD)) {
 		pserialize_read_exit(s);
 		pool_put(&ipcomp_tdb_crypto_pool, tc);
 		crypto_freereq(crp);
@@ -492,6 +513,7 @@ ipcomp_output(
 		goto bad;
 	}
 	KEY_SP_REF(isr->sp);
+	KEY_SA_REF(sav);
 	pserialize_read_exit(s);
     }
 
@@ -502,7 +524,6 @@ ipcomp_output(
 	tc->tc_skip = skip;
 	tc->tc_protoff = protoff;
 	tc->tc_sav = sav;
-	KEY_SA_REF(sav);
 
 	/* Crypto operation descriptor */
 	crp->crp_ilen = m->m_pkthdr.len;	/* Total input length */

Reply via email to