Always take the chain lock when accessing the filter_chain list, instead of relying on the rtnl lock. Dereference filter_chain with the tcf_chain_dereference() lockdep macro to verify that all users of the filter_chain list hold the chain lock.
Rearrange tp insert/remove code in tc_new_tfilter/tc_del_tfilter to execute all necessary code while holding chain lock in order to prevent invalidation of chain_info structure by potential concurrent change. Signed-off-by: Vlad Buslov <vla...@mellanox.com> Acked-by: Jiri Pirko <j...@mellanox.com> --- include/net/sch_generic.h | 17 ++++++++ net/sched/cls_api.c | 103 +++++++++++++++++++++++++++++----------------- 2 files changed, 83 insertions(+), 37 deletions(-) diff --git a/include/net/sch_generic.h b/include/net/sch_generic.h index 5f4fc28fc77a..82fa23da4969 100644 --- a/include/net/sch_generic.h +++ b/include/net/sch_generic.h @@ -338,6 +338,8 @@ struct qdisc_skb_cb { typedef void tcf_chain_head_change_t(struct tcf_proto *tp_head, void *priv); struct tcf_chain { + /* Protects filter_chain and filter_chain_list. */ + spinlock_t filter_chain_lock; struct tcf_proto __rcu *filter_chain; struct list_head list; struct tcf_block *block; @@ -371,6 +373,21 @@ struct tcf_block { struct rcu_head rcu; }; +#ifdef CONFIG_PROVE_LOCKING +static inline bool lockdep_tcf_chain_is_locked(struct tcf_chain *chain) +{ + return lockdep_is_held(&chain->filter_chain_lock); +} +#else +static inline bool lockdep_tcf_chain_is_locked(struct tcf_block *chain) +{ + return true; +} +#endif /* #ifdef CONFIG_PROVE_LOCKING */ + +#define tcf_chain_dereference(p, chain) \ + rcu_dereference_protected(p, lockdep_tcf_chain_is_locked(chain)) + static inline void tcf_block_offload_inc(struct tcf_block *block, u32 *flags) { if (*flags & TCA_CLS_FLAGS_IN_HW) diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index bd0dac35e26b..8f5dfa3ffb1c 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -237,6 +237,7 @@ static struct tcf_chain *tcf_chain_create(struct tcf_block *block, if (!chain) return NULL; list_add_tail(&chain->list, &block->chain_list); + spin_lock_init(&chain->filter_chain_lock); chain->block = block; chain->index = chain_index; chain->refcnt = 1; @@ -453,9 +454,13 @@ static void 
tcf_chain_put_explicitly_created(struct tcf_chain *chain) static void tcf_chain_flush(struct tcf_chain *chain) { - struct tcf_proto *tp = rtnl_dereference(chain->filter_chain); + struct tcf_proto *tp; + spin_lock(&chain->filter_chain_lock); + tp = tcf_chain_dereference(chain->filter_chain, chain); tcf_chain0_head_change(chain, NULL); + spin_unlock(&chain->filter_chain_lock); + while (tp) { RCU_INIT_POINTER(chain->filter_chain, tp->next); tcf_proto_destroy(tp, NULL); @@ -548,8 +553,15 @@ tcf_chain0_head_change_cb_add(struct tcf_block *block, spin_lock(&block->lock); chain0 = block->chain0.chain; - if (chain0 && chain0->filter_chain) - tcf_chain_head_change_item(item, chain0->filter_chain); + if (chain0) { + struct tcf_proto *tp_head; + + spin_lock(&chain0->filter_chain_lock); + tp_head = tcf_chain_dereference(chain0->filter_chain, chain0); + if (tp_head) + tcf_chain_head_change_item(item, tp_head); + spin_unlock(&chain0->filter_chain_lock); + } list_add(&item->list, &block->chain0.filter_chain_list); spin_unlock(&block->lock); @@ -1251,9 +1263,10 @@ struct tcf_chain_info { struct tcf_proto __rcu *next; }; -static struct tcf_proto *tcf_chain_tp_prev(struct tcf_chain_info *chain_info) +static struct tcf_proto *tcf_chain_tp_prev(struct tcf_chain *chain, + struct tcf_chain_info *chain_info) { - return rtnl_dereference(*chain_info->pprev); + return tcf_chain_dereference(*chain_info->pprev, chain); } static void tcf_chain_tp_insert(struct tcf_chain *chain, @@ -1262,7 +1275,7 @@ static void tcf_chain_tp_insert(struct tcf_chain *chain, { if (*chain_info->pprev == chain->filter_chain) tcf_chain0_head_change(chain, tp); - RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain_info)); + RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain, chain_info)); rcu_assign_pointer(*chain_info->pprev, tp); tcf_chain_hold(chain); } @@ -1271,7 +1284,7 @@ static void tcf_chain_tp_remove(struct tcf_chain *chain, struct tcf_chain_info *chain_info, struct tcf_proto *tp) { - struct tcf_proto *next = 
rtnl_dereference(chain_info->next); + struct tcf_proto *next = tcf_chain_dereference(chain_info->next, chain); if (tp == chain->filter_chain) tcf_chain0_head_change(chain, next); @@ -1288,7 +1301,8 @@ static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, /* Check the chain for existence of proto-tcf with this priority */ for (pprev = &chain->filter_chain; - (tp = rtnl_dereference(*pprev)); pprev = &tp->next) { + (tp = tcf_chain_dereference(*pprev, chain)); + pprev = &tp->next) { if (tp->prio >= prio) { if (tp->prio == prio) { if (prio_allocate || @@ -1496,12 +1510,13 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, goto errout; } + spin_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, prio_allocate); if (IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = PTR_ERR(tp); - goto errout; + goto errout_locked; } if (tp == NULL) { @@ -1510,29 +1525,37 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, if (tca[TCA_KIND] == NULL || !protocol) { NL_SET_ERR_MSG(extack, "Filter kind and protocol must be specified"); err = -EINVAL; - goto errout; + goto errout_locked; } if (!(n->nlmsg_flags & NLM_F_CREATE)) { NL_SET_ERR_MSG(extack, "Need both RTM_NEWTFILTER and NLM_F_CREATE to create a new filter"); err = -ENOENT; - goto errout; + goto errout_locked; } if (prio_allocate) - prio = tcf_auto_prio(tcf_chain_tp_prev(&chain_info)); + prio = tcf_auto_prio(tcf_chain_tp_prev(chain, + &chain_info)); + spin_unlock(&chain->filter_chain_lock); tp = tcf_proto_create(nla_data(tca[TCA_KIND]), protocol, prio, chain, extack); if (IS_ERR(tp)) { err = PTR_ERR(tp); goto errout; } + + spin_lock(&chain->filter_chain_lock); + tcf_chain_tp_insert(chain, &chain_info, tp); + spin_unlock(&chain->filter_chain_lock); tp_created = 1; } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { NL_SET_ERR_MSG(extack, "Specified filter kind does not match 
existing one"); err = -EINVAL; - goto errout; + goto errout_locked; + } else { + spin_unlock(&chain->filter_chain_lock); } fh = tp->ops->get(tp, t->tcm_handle); @@ -1558,21 +1581,11 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh, n->nlmsg_flags & NLM_F_CREATE ? TCA_ACT_NOREPLACE : TCA_ACT_REPLACE, extack); - if (err == 0) { - if (tp_created) - tcf_chain_tp_insert(chain, &chain_info, tp); + if (err == 0) tfilter_notify(net, skb, n, tp, block, q, parent, fh, RTM_NEWTFILTER, false); - } else { - if (tp_created) { - /* tp wasn't inserted to chain tp list. Take reference - * to chain manually for tcf_proto_destroy() to - * release. - */ - tcf_chain_hold(chain); - tcf_proto_destroy(tp, NULL); - } - } + else if (tp_created) + tcf_proto_destroy(tp, NULL); errout: if (chain) @@ -1582,6 +1595,10 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, /* Replay the request. */ goto replay; return err; + +errout_locked: + spin_unlock(&chain->filter_chain_lock); + goto errout; } static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, @@ -1657,31 +1674,34 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, goto errout; } + spin_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, false); if (!tp || IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = tp ? 
PTR_ERR(tp) : -ENOENT; - goto errout; + goto errout_locked; } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); err = -EINVAL; + goto errout_locked; + } else if (t->tcm_handle == 0) { + tcf_chain_tp_remove(chain, &chain_info, tp); + spin_unlock(&chain->filter_chain_lock); + + tfilter_notify(net, skb, n, tp, block, q, parent, fh, + RTM_DELTFILTER, false); + tcf_proto_destroy(tp, extack); + err = 0; goto errout; } + spin_unlock(&chain->filter_chain_lock); fh = tp->ops->get(tp, t->tcm_handle); if (!fh) { - if (t->tcm_handle == 0) { - tcf_chain_tp_remove(chain, &chain_info, tp); - tfilter_notify(net, skb, n, tp, block, q, parent, fh, - RTM_DELTFILTER, false); - tcf_proto_destroy(tp, extack); - err = 0; - } else { - NL_SET_ERR_MSG(extack, "Specified filter handle not found"); - err = -ENOENT; - } + NL_SET_ERR_MSG(extack, "Specified filter handle not found"); + err = -ENOENT; } else { bool last; @@ -1691,7 +1711,10 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, if (err) goto errout; if (last) { + spin_lock(&chain->filter_chain_lock); tcf_chain_tp_remove(chain, &chain_info, tp); + spin_unlock(&chain->filter_chain_lock); + tcf_proto_destroy(tp, extack); } } @@ -1701,6 +1724,10 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, tcf_chain_put(chain); tcf_block_release(q, block); return err; + +errout_locked: + spin_unlock(&chain->filter_chain_lock); + goto errout; } static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, @@ -1758,8 +1785,10 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, goto errout; } + spin_lock(&chain->filter_chain_lock); tp = tcf_chain_tp_find(chain, &chain_info, protocol, prio, false); + spin_unlock(&chain->filter_chain_lock); if (!tp || IS_ERR(tp)) { NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); err = tp ? PTR_ERR(tp) : -ENOENT; -- 2.7.5