A peer connected via UDP may change its IP address without reconnecting
(float).

Add support for detecting such a change and updating the peer's remote
IP/port accordingly.

Signed-off-by: Antonio Quartulli <anto...@openvpn.net>
---
 drivers/net/ovpn/bind.c |  10 ++--
 drivers/net/ovpn/io.c   |   9 ++++
 drivers/net/ovpn/peer.c | 128 ++++++++++++++++++++++++++++++++++++++++++++++--
 drivers/net/ovpn/peer.h |   2 +
 4 files changed, 138 insertions(+), 11 deletions(-)

diff --git a/drivers/net/ovpn/bind.c b/drivers/net/ovpn/bind.c
index b4d2ccec2ceddf43bc445b489cc62a578ef0ad0a..d17d078c5730bf4336dc87f45cdba3f6b8cad770 100644
--- a/drivers/net/ovpn/bind.c
+++ b/drivers/net/ovpn/bind.c
@@ -47,12 +47,8 @@ struct ovpn_bind *ovpn_bind_from_sockaddr(const struct sockaddr_storage *ss)
  * @new: the new bind to assign
  */
 void ovpn_bind_reset(struct ovpn_peer *peer, struct ovpn_bind *new)
+       __must_hold(&peer->lock) /* sparse: caller must hold peer->lock */
 {
-       struct ovpn_bind *old;
-
-       spin_lock_bh(&peer->lock);
-       old = rcu_replace_pointer(peer->bind, new, true);
-       spin_unlock_bh(&peer->lock);
-
-       kfree_rcu(old, rcu);
+       kfree_rcu(rcu_replace_pointer(peer->bind, new,
+                                     lockdep_is_held(&peer->lock)), rcu);
 }
diff --git a/drivers/net/ovpn/io.c b/drivers/net/ovpn/io.c
index 4e69f31382d2cb9ce4bc40f06cfbae47add5b5ba..8f2b4a85d20fbdb512de7ec312d391985e96b906 100644
--- a/drivers/net/ovpn/io.c
+++ b/drivers/net/ovpn/io.c
@@ -133,6 +133,15 @@ void ovpn_decrypt_post(void *data, int ret)
        /* keep track of last received authenticated packet for keepalive */
        peer->last_recv = ktime_get_real_seconds();
 
+       if (peer->sock->sock->sk->sk_protocol == IPPROTO_UDP) {
+               /* only UDP peers can float: check if this peer changed
+                * its IP address/port and update its remote endpoint
+                */
+               ovpn_peer_float(peer, skb);
+               /* update source endpoint for this peer */
+               ovpn_peer_update_local_endpoint(peer, skb);
+       }
+
        /* point to encapsulated IP packet */
        __skb_pull(skb, payload_offset);
 
diff --git a/drivers/net/ovpn/peer.c b/drivers/net/ovpn/peer.c
index f9d8f1d1827fe67dc4b4e0bba41a5b110bb90819..891cf2fa6a81b46ac764ad4fd50d4456fa7ce5bd 100644
--- a/drivers/net/ovpn/peer.c
+++ b/drivers/net/ovpn/peer.c
@@ -94,6 +94,128 @@ struct ovpn_peer *ovpn_peer_new(struct ovpn_struct *ovpn, u32 id)
        return peer;
 }
 
+/**
+ * ovpn_peer_reset_sockaddr - recreate binding for peer
+ * @peer: peer to recreate the binding for
+ * @ss: sockaddr to use as remote endpoint for the binding
+ * @local_ip: local IP to copy into the binding, or NULL to skip it
+ *
+ * Return: 0 on success or a negative error code otherwise
+ */
+static int ovpn_peer_reset_sockaddr(struct ovpn_peer *peer,
+                                   const struct sockaddr_storage *ss,
+                                   const u8 *local_ip)
+       __must_hold(&peer->lock)
+{
+       struct ovpn_bind *bind;
+       size_t ip_len;
+
+       /* create new ovpn_bind object */
+       bind = ovpn_bind_from_sockaddr(ss);
+       if (IS_ERR(bind))
+               return PTR_ERR(bind);
+
+       if (local_ip) {
+               if (ss->ss_family == AF_INET) {
+                       ip_len = sizeof(struct in_addr);
+               } else if (ss->ss_family == AF_INET6) {
+                       ip_len = sizeof(struct in6_addr);
+               } else {
+                       netdev_dbg(peer->ovpn->dev, "%s: invalid family for remote endpoint\n",
+                                  __func__);
+                       kfree(bind);
+                       return -EINVAL;
+               }
+
+               memcpy(&bind->local, local_ip, ip_len);
+       }
+
+       /* set binding */
+       ovpn_bind_reset(peer, bind);
+
+       return 0;
+}
+
+#define ovpn_get_hash_head(_tbl, _key, _key_len) ({            \
+       typeof(_tbl) *__tbl = &(_tbl);                          \
+       (&(*__tbl)[jhash(_key, _key_len, 0) % HASH_SIZE(*__tbl)]); })
+
+/**
+ * ovpn_peer_float - update remote endpoint for peer
+ * @peer: peer to update the remote endpoint for
+ * @skb: incoming packet to retrieve the source address (remote) from
+ */
+void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb)
+{
+       struct sockaddr_storage ss = {};
+       struct hlist_nulls_head *nhead;
+       const u8 *local_ip = NULL;
+       struct sockaddr_in6 *sa6;
+       struct sockaddr_in *sa;
+       struct ovpn_bind *bind;
+       sa_family_t family;
+       size_t salen;
+
+       rcu_read_lock();
+       spin_lock_bh(&peer->lock);
+       /* read bind under peer->lock: a concurrent float cannot swap it */
+       bind = rcu_dereference_protected(peer->bind,
+                                        lockdep_is_held(&peer->lock));
+       if (unlikely(!bind))
+               goto unlock;
+
+       if (likely(ovpn_bind_skb_src_match(bind, skb)))
+               goto unlock;
+
+       family = skb_protocol_to_family(skb);
+       /* keep the local IP only if the address family did not change */
+       if (bind->remote.in4.sin_family == family)
+               local_ip = (u8 *)&bind->local;
+
+       switch (family) {
+       case AF_INET:
+               sa = (struct sockaddr_in *)&ss;
+               sa->sin_family = AF_INET;
+               sa->sin_addr.s_addr = ip_hdr(skb)->saddr;
+               sa->sin_port = udp_hdr(skb)->source;
+               salen = sizeof(*sa);
+               break;
+       case AF_INET6:
+               sa6 = (struct sockaddr_in6 *)&ss;
+               sa6->sin6_family = AF_INET6;
+               sa6->sin6_addr = ipv6_hdr(skb)->saddr;
+               sa6->sin6_port = udp_hdr(skb)->source;
+               sa6->sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr,
+                                                        skb->skb_iif);
+               salen = sizeof(*sa6);
+               break;
+       default:
+               goto unlock;
+       }
+
+       if (ovpn_peer_reset_sockaddr(peer, &ss, local_ip) < 0)
+               goto unlock;
+       netdev_dbg(peer->ovpn->dev, "%s: peer %u floated to %pIScp\n",
+                  __func__, peer->id, &ss);
+
+       /* rehash only in MP mode (P2P has one peer and no hashtable);
+        * ss is zero-initialized, so all salen bytes hashed are defined
+        */
+       if (peer->ovpn->mode == OVPN_MODE_MP) {
+               spin_lock_bh(&peer->ovpn->peers->lock);
+               /* remove old hashing */
+               hlist_nulls_del_init_rcu(&peer->hash_entry_transp_addr);
+               /* re-add with new transport address */
+               nhead = ovpn_get_hash_head(peer->ovpn->peers->by_transp_addr,
+                                          &ss, salen);
+               hlist_nulls_add_head_rcu(&peer->hash_entry_transp_addr, nhead);
+               spin_unlock_bh(&peer->ovpn->peers->lock);
+       }
+unlock:
+       spin_unlock_bh(&peer->lock);
+       rcu_read_unlock();
+}
+
 /**
  * ovpn_peer_release_rcu - release peer private members
  * @head: RCU head belonging to peer being released
@@ -103,7 +225,9 @@ static void ovpn_peer_release_rcu(struct rcu_head *head)
        struct ovpn_peer *peer = container_of(head, struct ovpn_peer, rcu);

        ovpn_crypto_state_release(&peer->crypto);
+       spin_lock_bh(&peer->lock); /* required by ovpn_bind_reset() */
        ovpn_bind_reset(peer, NULL);
+       spin_unlock_bh(&peer->lock);

        dst_cache_destroy(&peer->dst_cache);
 }
@@ -188,10 +312,6 @@ static struct in6_addr ovpn_nexthop_from_skb6(struct sk_buff *skb)
        return rt->rt6i_gateway;
 }
 
-#define ovpn_get_hash_head(_tbl, _key, _key_len) ({            \
-       typeof(_tbl) *__tbl = &(_tbl);                          \
-       (&(*__tbl)[jhash(_key, _key_len, 0) % HASH_SIZE(*__tbl)]); }) \
-
 /**
  * ovpn_peer_get_by_vpn_addr4 - retrieve peer by its VPN IPv4 address
  * @ovpn: the openvpn instance to search
diff --git a/drivers/net/ovpn/peer.h b/drivers/net/ovpn/peer.h
index 6b66e169b33510a794f8f43dff757f0357e3e5de..ea1a014568c64d796bb18447ceb70e801bfaf3f2 100644
--- a/drivers/net/ovpn/peer.h
+++ b/drivers/net/ovpn/peer.h
@@ -162,4 +162,6 @@ void ovpn_peer_keepalive_work(struct work_struct *work);
 void ovpn_peer_update_local_endpoint(struct ovpn_peer *peer,
                                     struct sk_buff *skb);
 
+void ovpn_peer_float(struct ovpn_peer *peer, struct sk_buff *skb);
+
 #endif /* _NET_OVPN_OVPNPEER_H_ */

-- 
2.45.2


Reply via email to