This corresponds to the NT syscall NtCreateMutant().

An NT mutex is recursive, with a 32-bit recursion counter. When acquired via
NtWaitForMultipleObjects(), the recursion counter is incremented by one.

The OS records the thread which acquired it. However, in order to keep this
driver self-contained, the owning thread ID is managed by user-space, and passed
as a parameter to all relevant ioctls.

The initial owner and recursion count, if any, are specified when the mutex is
created.

Signed-off-by: Elizabeth Figura <zfig...@codeweavers.com>
---
 drivers/misc/ntsync.c       | 68 +++++++++++++++++++++++++++++++++++++
 include/uapi/linux/ntsync.h |  7 ++++
 2 files changed, 75 insertions(+)

diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c
index e914d626465a..173513aeeacc 100644
--- a/drivers/misc/ntsync.c
+++ b/drivers/misc/ntsync.c
@@ -24,6 +24,7 @@
 
 enum ntsync_type {
        NTSYNC_TYPE_SEM,
+       NTSYNC_TYPE_MUTEX, /* recursive mutex; state lives in ntsync_obj.u.mutex */
 };
 
 /*
@@ -53,6 +54,10 @@ struct ntsync_obj {
                        __u32 count;
                        __u32 max;
                } sem;
+               struct {
+                       __u32 count;
+                       __u32 owner;
+               } mutex;
        } u;
 
        /*
@@ -132,6 +137,10 @@ static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
        switch (obj->type) {
        case NTSYNC_TYPE_SEM:
                return !!obj->u.sem.count;
+       case NTSYNC_TYPE_MUTEX:
+               if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
+                       return false;
+               return obj->u.mutex.count < UINT_MAX;
        }
 
        WARN(1, "bad object type %#x\n", obj->type);
@@ -175,6 +184,10 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
                        case NTSYNC_TYPE_SEM:
                                obj->u.sem.count--;
                                break;
+                       case NTSYNC_TYPE_MUTEX:
+                               obj->u.mutex.count++;
+                               obj->u.mutex.owner = q->owner;
+                               break;
                        }
                }
                wake_up_process(q->task);
@@ -217,6 +230,29 @@ static void try_wake_any_sem(struct ntsync_obj *sem)
        }
 }
 
+static void try_wake_any_mutex(struct ntsync_obj *mutex) /* grant the mutex to eligible waiters on its any_waiters list */
+{
+       struct ntsync_q_entry *entry;
+
+       lockdep_assert_held(&mutex->lock); /* caller must hold the object's lock */
+
+       list_for_each_entry(entry, &mutex->any_waiters, node) {
+               struct ntsync_q *q = entry->q;
+               int signaled = -1; /* value q->signaled must still hold for the cmpxchg to succeed */
+
+               if (mutex->u.mutex.count == UINT_MAX) /* recursion count saturated: nobody can acquire */
+                       break;
+               if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner) /* owned by a different ID */
+                       continue;
+
+               if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { /* claim q, storing this entry's index */
+                       mutex->u.mutex.count++; /* acquire on the waiter's behalf */
+                       mutex->u.mutex.owner = q->owner;
+                       wake_up_process(q->task);
+               }
+       }
+}
+
 /*
  * Actually change the semaphore state, returning -EOVERFLOW if it is made
  * invalid.
@@ -376,6 +412,33 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
        return put_user(fd, &user_args->sem);
 }
 
+static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp) /* NTSYNC_IOC_CREATE_MUTEX: create a mutex, return its fd via args */
+{
+       struct ntsync_mutex_args __user *user_args = argp;
+       struct ntsync_mutex_args args;
+       struct ntsync_obj *mutex;
+       int fd;
+
+       if (copy_from_user(&args, argp, sizeof(args)))
+               return -EFAULT;
+
+       if (!args.owner != !args.count) /* owner and count must be both zero or both nonzero */
+               return -EINVAL;
+
+       mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
+       if (!mutex)
+               return -ENOMEM;
+       mutex->u.mutex.count = args.count;
+       mutex->u.mutex.owner = args.owner;
+       fd = ntsync_obj_get_fd(mutex);
+       if (fd < 0) {
+               kfree(mutex); /* NOTE(review): confirm ntsync_alloc_obj() holds nothing else that needs releasing */
+               return fd;
+       }
+
+       return put_user(fd, &user_args->mutex); /* NOTE(review): fd is not released if this copy faults - confirm intended */
+}
+
 static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
 {
        struct file *file = fget(fd);
@@ -505,6 +568,9 @@ static void try_wake_any_obj(struct ntsync_obj *obj)
        case NTSYNC_TYPE_SEM:
                try_wake_any_sem(obj);
                break;
+       case NTSYNC_TYPE_MUTEX:
+               try_wake_any_mutex(obj);
+               break;
        }
 }
 
@@ -693,6 +759,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
        void __user *argp = (void __user *)parm;
 
        switch (cmd) {
+       case NTSYNC_IOC_CREATE_MUTEX:
+               return ntsync_create_mutex(dev, argp);
        case NTSYNC_IOC_CREATE_SEM:
                return ntsync_create_sem(dev, argp);
        case NTSYNC_IOC_WAIT_ALL:
diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h
index 83784d4438a1..cd7841cdba49 100644
--- a/include/uapi/linux/ntsync.h
+++ b/include/uapi/linux/ntsync.h
@@ -16,6 +16,12 @@ struct ntsync_sem_args {
        __u32 max;
 };
 
+struct ntsync_mutex_args {
+       __u32 mutex; /* out: receives the fd of the created mutex */
+       __u32 owner; /* in: initial owner ID; zero if and only if count is zero */
+       __u32 count; /* in: initial recursion count */
+};
+
 #define NTSYNC_WAIT_REALTIME   0x1
 
 struct ntsync_wait_args {
@@ -34,6 +40,7 @@ struct ntsync_wait_args {
 #define NTSYNC_IOC_CREATE_SEM          _IOWR('N', 0x80, struct ntsync_sem_args)
 #define NTSYNC_IOC_WAIT_ANY            _IOWR('N', 0x82, struct 
ntsync_wait_args)
 #define NTSYNC_IOC_WAIT_ALL            _IOWR('N', 0x83, struct 
ntsync_wait_args)
+#define NTSYNC_IOC_CREATE_MUTEX                _IOWR('N', 0x84, struct 
ntsync_sem_args)
 
 #define NTSYNC_IOC_SEM_POST            _IOWR('N', 0x81, __u32)
 
-- 
2.43.0


Reply via email to