In preparation for registering rules update netlink handlers as unlocked,
conditionally take rtnl in the following cases:
- Parent qdisc doesn't support unlocked execution.
- Requested classifier type doesn't support unlocked execution.
- User requested to flush the whole chain using the old filter update API,
instead of the new chains API.

Add helper function tcf_require_rtnl() that takes rtnl only when the
specified condition is true and the lock hasn't been taken already.

Signed-off-by: Vlad Buslov <vla...@mellanox.com>
Acked-by: Jiri Pirko <j...@mellanox.com>
---
 net/sched/cls_api.c | 74 +++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 63 insertions(+), 11 deletions(-)

diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c
index 1956c5df5f89..848f148f1019 100644
--- a/net/sched/cls_api.c
+++ b/net/sched/cls_api.c
@@ -179,9 +179,25 @@ static void tcf_proto_destroy_work(struct work_struct 
*work)
        rtnl_unlock();
 }
 
+/* Helper function to lock rtnl mutex when specified condition is true and 
mutex
+ * hasn't been locked yet. Will set rtnl_held to 'true' before taking rtnl 
lock.
+ * Note that this function does nothing if rtnl is already held. This is
+ * intended to be used by cls API rules update API when multiple conditions
+ * could require rtnl lock and its state needs to be tracked to prevent trying
+ * to obtain lock multiple times.
+ */
+
+static void tcf_require_rtnl(bool cond, bool *rtnl_held)
+{
+       if (!*rtnl_held && cond) {
+               *rtnl_held = true;
+               rtnl_lock();
+       }
+}
+
 static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol,
                                          u32 prio, struct tcf_chain *chain,
-                                         bool rtnl_held,
+                                         bool *rtnl_held,
                                          struct netlink_ext_ack *extack)
 {
        struct tcf_proto *tp;
@@ -191,7 +207,7 @@ static struct tcf_proto *tcf_proto_create(const char *kind, 
u32 protocol,
        if (!tp)
                return ERR_PTR(-ENOBUFS);
 
-       tp->ops = tcf_proto_lookup_ops(kind, rtnl_held, extack);
+       tp->ops = tcf_proto_lookup_ops(kind, *rtnl_held, extack);
        if (IS_ERR(tp->ops)) {
                err = PTR_ERR(tp->ops);
                goto errout;
@@ -204,6 +220,8 @@ static struct tcf_proto *tcf_proto_create(const char *kind, 
u32 protocol,
        spin_lock_init(&tp->lock);
        refcount_set(&tp->refcnt, 1);
 
+       tcf_require_rtnl(!(tp->ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED),
+                        rtnl_held);
        err = tp->ops->init(tp);
        if (err) {
                module_put(tp->ops->owner);
@@ -888,6 +906,7 @@ static void tcf_block_refcnt_put(struct tcf_block *block)
 static struct tcf_block *tcf_block_find(struct net *net, struct Qdisc **q,
                                        u32 *parent, unsigned long *cl,
                                        int ifindex, u32 block_index,
+                                       bool *rtnl_held,
                                        struct netlink_ext_ack *extack)
 {
        struct tcf_block *block;
@@ -953,6 +972,12 @@ static struct tcf_block *tcf_block_find(struct net *net, 
struct Qdisc **q,
                 */
                rcu_read_unlock();
 
+               /* Take rtnl mutex if qdisc doesn't support unlocked
+                * execution.
+                */
+               tcf_require_rtnl(!(cops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED),
+                                rtnl_held);
+
                /* Do we search for filter, attached to class? */
                if (TC_H_MIN(*parent)) {
                        *cl = cops->find(*q, *parent);
@@ -990,7 +1015,10 @@ static struct tcf_block *tcf_block_find(struct net *net, 
struct Qdisc **q,
        rcu_read_unlock();
 errout_qdisc:
        if (*q) {
-               qdisc_put(*q);
+               if (*rtnl_held)
+                       qdisc_put(*q);
+               else
+                       qdisc_put_unlocked(*q);
                *q = NULL;
        }
        return ERR_PTR(err);
@@ -1678,7 +1706,7 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
        void *fh;
        int err;
        int tp_created;
-       bool rtnl_held;
+       bool rtnl_held = true;
 
        if (!netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN))
                return -EPERM;
@@ -1697,7 +1725,6 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
        parent = t->tcm_parent;
        tp = NULL;
        cl = 0;
-       rtnl_held = true;
 
        if (prio == 0) {
                /* If no priority is provided by the user,
@@ -1715,7 +1742,8 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
        /* Find head of filter chain. */
 
        block = tcf_block_find(net, &q, &parent, &cl,
-                              t->tcm_ifindex, t->tcm_block_index, extack);
+                              t->tcm_ifindex, t->tcm_block_index, &rtnl_held,
+                              extack);
        if (IS_ERR(block)) {
                err = PTR_ERR(block);
                goto errout;
@@ -1766,7 +1794,7 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
 
                spin_unlock(&chain->filter_chain_lock);
                tp_new = tcf_proto_create(nla_data(tca[TCA_KIND]),
-                                         protocol, prio, chain, rtnl_held,
+                                         protocol, prio, chain, &rtnl_held,
                                          extack);
                if (IS_ERR(tp_new)) {
                        err = PTR_ERR(tp_new);
@@ -1788,6 +1816,10 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
                spin_unlock(&chain->filter_chain_lock);
        }
 
+       /* take rtnl mutex if classifier doesn't support unlocked execution */
+       tcf_require_rtnl(!(tp->ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED),
+                        &rtnl_held);
+
        if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) {
                NL_SET_ERR_MSG(extack, "Specified filter kind does not match 
existing one");
                err = -EINVAL;
@@ -1834,9 +1866,14 @@ static int tc_new_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
                        tcf_chain_put(chain);
        }
        tcf_block_release(q, block);
-       if (err == -EAGAIN)
+       if (err == -EAGAIN) {
+               /* Take rtnl lock in case EAGAIN is caused by concurrent flush
+                * of target chain.
+                */
+               tcf_require_rtnl(true, &rtnl_held);
                /* Replay the request. */
                goto replay;
+       }
        return err;
 
 errout_locked:
@@ -1881,10 +1918,16 @@ static int tc_del_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
                return -ENOENT;
        }
 
+       /* Always take rtnl mutex when flushing whole chain in order to
+        * synchronize with chain locked API.
+        */
+       tcf_require_rtnl(!prio, &rtnl_held);
+
        /* Find head of filter chain. */
 
        block = tcf_block_find(net, &q, &parent, &cl,
-                              t->tcm_ifindex, t->tcm_block_index, extack);
+                              t->tcm_ifindex, t->tcm_block_index, &rtnl_held,
+                              extack);
        if (IS_ERR(block)) {
                err = PTR_ERR(block);
                goto errout;
@@ -1941,6 +1984,9 @@ static int tc_del_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
        }
        spin_unlock(&chain->filter_chain_lock);
 
+       /* take rtnl mutex if classifier doesn't support unlocked execution */
+       tcf_require_rtnl(!(tp->ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED),
+                        &rtnl_held);
        fh = tp->ops->get(tp, t->tcm_handle);
 
        if (!fh) {
@@ -2010,7 +2056,8 @@ static int tc_get_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
        /* Find head of filter chain. */
 
        block = tcf_block_find(net, &q, &parent, &cl,
-                              t->tcm_ifindex, t->tcm_block_index, extack);
+                              t->tcm_ifindex, t->tcm_block_index, &rtnl_held,
+                              extack);
        if (IS_ERR(block)) {
                err = PTR_ERR(block);
                goto errout;
@@ -2043,6 +2090,9 @@ static int tc_get_tfilter(struct sk_buff *skb, struct 
nlmsghdr *n,
                goto errout;
        }
 
+       /* take rtnl mutex if classifier doesn't support unlocked execution */
+       tcf_require_rtnl(!(tp->ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED),
+                        &rtnl_held);
        fh = tp->ops->get(tp, t->tcm_handle);
 
        if (!fh) {
@@ -2397,6 +2447,7 @@ static int tc_ctl_chain(struct sk_buff *skb, struct 
nlmsghdr *n,
        struct Qdisc *q = NULL;
        struct tcf_chain *chain = NULL;
        struct tcf_block *block;
+       bool rtnl_held = true;
        unsigned long cl;
        int err;
 
@@ -2414,7 +2465,8 @@ static int tc_ctl_chain(struct sk_buff *skb, struct 
nlmsghdr *n,
        cl = 0;
 
        block = tcf_block_find(net, &q, &parent, &cl,
-                              t->tcm_ifindex, t->tcm_block_index, extack);
+                              t->tcm_ifindex, t->tcm_block_index, &rtnl_held,
+                              extack);
        if (IS_ERR(block))
                return PTR_ERR(block);
 
-- 
2.7.5

Reply via email to