On Wed, Feb 12, 2020 at 01:43:32PM -0500, Marcelo Diop-Gonzalez wrote:
> Currently reference counts are implemented by locking service_spinlock
> and then incrementing the service's ->ref_count field, calling
> kfree() when the last reference has been dropped. But at the same
> time, there's code in multiple places that dereferences pointers
> to services without having a reference, so there could be a race there.
> 
> It should be possible to avoid taking any lock in unlock_service()
> or service_release() because we are setting a single array element
> to NULL, and on service creation, a mutex is locked before looking
> for a NULL spot to put the new service in.
> 
> Using a struct kref and RCU-delaying the freeing of services fixes
> this race condition while still making it possible to skip
> grabbing a reference in many places. It also avoids contention on
> a single global spinlock, e.g. when one thread takes a reference on
> state->services[i] while another is taking a reference on
> state->services[j].
> 
> Signed-off-by: Marcelo Diop-Gonzalez <marcgonza...@google.com>
> ---
>  .../interface/vchiq_arm/vchiq_arm.c           |  25 +-
>  .../interface/vchiq_arm/vchiq_core.c          | 222 +++++++++---------
>  .../interface/vchiq_arm/vchiq_core.h          |  12 +-
>  3 files changed, 140 insertions(+), 119 deletions(-)
> 
> diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c 
> b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> index c456ced431af..3ed0e4ea7f5c 100644
> --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
> @@ -22,6 +22,7 @@
>  #include <linux/platform_device.h>
>  #include <linux/compat.h>
>  #include <linux/dma-mapping.h>
> +#include <linux/rcupdate.h>
>  #include <soc/bcm2835/raspberrypi-firmware.h>
>  
>  #include "vchiq_core.h"
> @@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
>       /* There is no list of instances, so instead scan all services,
>               marking those that have been dumped. */
>  
> +     rcu_read_lock();
>       for (i = 0; i < state->unused_service; i++) {
> -             struct vchiq_service *service = state->services[i];
> +             struct vchiq_service *service;
>               struct vchiq_instance *instance;
>  
> +             service = rcu_dereference(state->services[i]);
>               if (!service || service->base.callback != service_callback)
>                       continue;
>  
> @@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
>               if (instance)
>                       instance->mark = 0;
>       }
> +     rcu_read_unlock();
>  
>       for (i = 0; i < state->unused_service; i++) {
> -             struct vchiq_service *service = state->services[i];
> +             struct vchiq_service *service;
>               struct vchiq_instance *instance;
>               int err;
>  
> -             if (!service || service->base.callback != service_callback)
> +             rcu_read_lock();
> +             service = rcu_dereference(state->services[i]);
> +             if (!service || service->base.callback != service_callback) {
> +                     rcu_read_unlock();
>                       continue;
> +             }
>  
>               instance = service->instance;
> -             if (!instance || instance->mark)
> +             if (!instance || instance->mark) {
> +                     rcu_read_unlock();
>                       continue;
> +             }
> +             rcu_read_unlock();
>  
>               len = snprintf(buf, sizeof(buf),
>                              "Instance %pK: pid %d,%s completions %d/%d",
> @@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
>                              instance->completion_insert -
>                              instance->completion_remove,
>                              MAX_COMPLETIONS);
> -
>               err = vchiq_dump(dump_context, buf, len + 1);
>               if (err)
>                       return err;
> @@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
>       if (active_services > MAX_SERVICES)
>               only_nonzero = 1;
>  
> +     rcu_read_lock();
>       for (i = 0; i < active_services; i++) {
> -             struct vchiq_service *service_ptr = state->services[i];
> +             struct vchiq_service *service_ptr =
> +                     rcu_dereference(state->services[i]);
>  
>               if (!service_ptr)
>                       continue;
> @@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
>               if (found >= MAX_SERVICES)
>                       break;
>       }
> +     rcu_read_unlock();
>  
>       read_unlock_bh(&arm_state->susp_res_lock);
>  
> diff --git a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c 
> b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> index b2d9013b7f79..65270a5b29db 100644
> --- a/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> +++ b/drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
> @@ -1,6 +1,9 @@
>  // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
>  /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */
>  
> +#include <linux/kref.h>
> +#include <linux/rcupdate.h>
> +
>  #include "vchiq_core.h"
>  
>  #define VCHIQ_SLOT_HANDLER_STACK 8192
> @@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT;
>  int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
>  int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;
>  
> -static DEFINE_SPINLOCK(service_spinlock);
>  DEFINE_SPINLOCK(bulk_waiter_spinlock);
>  static DEFINE_SPINLOCK(quota_spinlock);
>  
> @@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle)
>  {
>       struct vchiq_service *service;
>  
> -     spin_lock(&service_spinlock);
> +     rcu_read_lock();
>       service = handle_to_service(handle);
>       if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
> -         service->handle == handle) {
> -             WARN_ON(service->ref_count == 0);
> -             service->ref_count++;
> -     } else
> -             service = NULL;
> -     spin_unlock(&service_spinlock);
> -
> -     if (!service)
> -             vchiq_log_info(vchiq_core_log_level,
> -                     "Invalid service handle 0x%x", handle);
> -
> -     return service;
> +         service->handle == handle &&
> +         kref_get_unless_zero(&service->ref_count)) {
> +             service = rcu_pointer_handoff(service);
> +             rcu_read_unlock();
> +             return service;
> +     }
> +     rcu_read_unlock();
> +     vchiq_log_info(vchiq_core_log_level,
> +                    "Invalid service handle 0x%x", handle);
> +     return NULL;
>  }
>  
>  struct vchiq_service *
>  find_service_by_port(struct vchiq_state *state, int localport)
>  {
> -     struct vchiq_service *service = NULL;
>  
>       if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
> -             spin_lock(&service_spinlock);
> -             service = state->services[localport];
> -             if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
> -                     WARN_ON(service->ref_count == 0);
> -                     service->ref_count++;
> -             } else
> -                     service = NULL;
> -             spin_unlock(&service_spinlock);
> -     }
> -
> -     if (!service)
> -             vchiq_log_info(vchiq_core_log_level,
> -                     "Invalid port %d", localport);
> +             struct vchiq_service *service;
>  
> -     return service;
> +             rcu_read_lock();
> +             service = rcu_dereference(state->services[localport]);
> +             if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
> +                 kref_get_unless_zero(&service->ref_count)) {
> +                     service = rcu_pointer_handoff(service);
> +                     rcu_read_unlock();
> +                     return service;
> +             }
> +             rcu_read_unlock();
> +     }
> +     vchiq_log_info(vchiq_core_log_level,
> +                    "Invalid port %d", localport);
> +     return NULL;
>  }
>  
>  struct vchiq_service *
> @@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance 
> *instance,
>  {
>       struct vchiq_service *service;
>  
> -     spin_lock(&service_spinlock);
> +     rcu_read_lock();
>       service = handle_to_service(handle);
>       if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
>           service->handle == handle &&
> -         service->instance == instance) {
> -             WARN_ON(service->ref_count == 0);
> -             service->ref_count++;
> -     } else
> -             service = NULL;
> -     spin_unlock(&service_spinlock);
> -
> -     if (!service)
> -             vchiq_log_info(vchiq_core_log_level,
> -                     "Invalid service handle 0x%x", handle);
> -
> -     return service;
> +         service->instance == instance &&
> +         kref_get_unless_zero(&service->ref_count)) {
> +             service = rcu_pointer_handoff(service);
> +             rcu_read_unlock();
> +             return service;
> +     }
> +     rcu_read_unlock();
> +     vchiq_log_info(vchiq_core_log_level,
> +                    "Invalid service handle 0x%x", handle);
> +     return NULL;
>  }
>  
>  struct vchiq_service *
> @@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance,
>  {
>       struct vchiq_service *service;
>  
> -     spin_lock(&service_spinlock);
> +     rcu_read_lock();
>       service = handle_to_service(handle);
>       if (service &&
>           (service->srvstate == VCHIQ_SRVSTATE_FREE ||
>            service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
>           service->handle == handle &&
> -         service->instance == instance) {
> -             WARN_ON(service->ref_count == 0);
> -             service->ref_count++;
> -     } else
> -             service = NULL;
> -     spin_unlock(&service_spinlock);
> -
> -     if (!service)
> -             vchiq_log_info(vchiq_core_log_level,
> -                     "Invalid service handle 0x%x", handle);
> -
> +         service->instance == instance &&
> +         kref_get_unless_zero(&service->ref_count)) {
> +             service = rcu_pointer_handoff(service);
> +             rcu_read_unlock();
> +             return service;
> +     }
> +     rcu_read_unlock();
> +     vchiq_log_info(vchiq_core_log_level,
> +                    "Invalid service handle 0x%x", handle);
> -     return service;
> +     return NULL;
>  }
>  
> @@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, 
> struct vchiq_instance *insta
>       struct vchiq_service *service = NULL;
>       int idx = *pidx;
>  
> -     spin_lock(&service_spinlock);
> +     rcu_read_lock();
>       while (idx < state->unused_service) {
> -             struct vchiq_service *srv = state->services[idx++];
> +             struct vchiq_service *srv;
>  
> +             srv = rcu_dereference(state->services[idx++]);
>               if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
> -                 srv->instance == instance) {
> -                     service = srv;
> -                     WARN_ON(service->ref_count == 0);
> -                     service->ref_count++;
> +                 srv->instance == instance &&
> +                 kref_get_unless_zero(&srv->ref_count)) {
> +                     service = rcu_pointer_handoff(srv);
>                       break;
>               }
>       }
> -     spin_unlock(&service_spinlock);
> +     rcu_read_unlock();
>  
>       *pidx = idx;
>  
> @@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, 
> struct vchiq_instance *insta
>  void
>  lock_service(struct vchiq_service *service)
>  {
> -     spin_lock(&service_spinlock);
> -     WARN_ON(!service);
> -     if (service) {
> -             WARN_ON(service->ref_count == 0);
> -             service->ref_count++;
> +     if (!service) {
> +             WARN(1, "%s service is NULL\n", __func__);
> +             return;
>       }
> -     spin_unlock(&service_spinlock);
> +     kref_get(&service->ref_count);
> +}
> +
> +static void service_release(struct kref *kref)
> +{
> +     struct vchiq_service *service =
> +             container_of(kref, struct vchiq_service, ref_count);
> +     struct vchiq_state *state = service->state;
> +
> +     WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
> +     rcu_assign_pointer(state->services[service->localport], NULL);
> +     if (service->userdata_term)
> +             service->userdata_term(service->base.userdata);
> +     kfree_rcu(service, rcu);
>  }

I think that's the first time I've seen krefs used with RCU.

It looks sane at first glance, but these are a lot of tricky changes, so I'll
assume you tested this and go merge it to see what breaks :)

thanks for doing this,

greg k-h
_______________________________________________
devel mailing list
de...@linuxdriverproject.org
http://driverdev.linuxdriverproject.org/mailman/listinfo/driverdev-devel

Reply via email to