Convert the struct socket pointer into a uint64_t and use it as the id for
the new socket to pass to the backend.
Signed-off-by: Stefano Stabellini <stef...@aporeto.com>
CC: boris.ostrov...@oracle.com
CC: jgr...@suse.com
---
drivers/xen/pvcalls-front.c | 177 +++++++++++++++++++++++++++++++++++++++++---
drivers/xen/pvcalls-front.h | 2 +
2 files changed, 168 insertions(+), 11 deletions(-)
diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c
index d1dbcf1..d0f5f42 100644
--- a/drivers/xen/pvcalls-front.c
+++ b/drivers/xen/pvcalls-front.c
@@ -13,6 +13,10 @@
*/
#include <linux/module.h>
+#include <linux/net.h>
+#include <linux/socket.h>
+
+#include <net/sock.h>
#include <xen/events.h>
#include <xen/grant_table.h>
@@ -40,6 +44,24 @@ struct pvcalls_bedata {
};
struct xenbus_device *pvcalls_front_dev;
+struct sock_mapping { /* frontend state for one socket; allocated by create_active() */
+ bool active_socket; /* set to true by create_active() */
+ struct list_head list; /* linked into bedata->socket_mappings by pvcalls_front_connect() */
+ struct socket *sock; /* socket this mapping was created for */
+ union { /* only the "active" variant is defined here */
+ struct {
+ int irq; /* irq bound to this socket's event channel */
+ grant_ref_t ref; /* grant reference of the page that holds "ring" */
+ struct pvcalls_data_intf *ring; /* single page; carries ring_order and the data-page ref[] */
+ struct pvcalls_data data; /* in/out pointers into the granted data pages */
+ struct mutex in_mutex; /* initialised in create_active(); NOTE(review): presumably guards the "in" ring - confirm against users */
+ struct mutex out_mutex; /* initialised in create_active(); NOTE(review): presumably guards the "out" ring - confirm against users */
+
+ wait_queue_head_t inflight_conn_req; /* woken by pvcalls_front_conn_handler() */
+ } active;
+ };
+};
+
static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
struct xenbus_device *dev = dev_id;
@@ -84,6 +106,18 @@ static irqreturn_t pvcalls_front_event_handler(int irq,
void *dev_id)
return IRQ_HANDLED;
}
+/*
+ * Interrupt handler for the per-socket event channel.
+ *
+ * @irq:      irq number (unused).
+ * @sock_map: the struct sock_mapping registered with the irq; may be NULL.
+ *
+ * Wakes any interruptible sleeper on the mapping's inflight_conn_req
+ * wait queue.  Always reports the interrupt as handled.
+ */
+static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
+{
+	struct sock_mapping *mapping = sock_map;
+
+	/* Nothing to wake when no mapping was attached to this irq. */
+	if (mapping != NULL)
+		wake_up_interruptible(&mapping->active.inflight_conn_req);
+
+	return IRQ_HANDLED;
+}
+
int pvcalls_front_socket(struct socket *sock)
{
struct pvcalls_bedata *bedata;
@@ -137,6 +171,127 @@ int pvcalls_front_socket(struct socket *sock)
return ret;
}
+/*
+ * Allocate and set up the frontend state for one active socket.
+ *
+ * Allocates a zeroed sock_mapping, one page for the pvcalls_data_intf
+ * ring and 2^RING_ORDER data pages, grants all of them to the backend
+ * domain, then allocates an event channel and binds it to
+ * pvcalls_front_conn_handler().
+ *
+ * @evtchn: out parameter; receives the event channel port on success.
+ *          It is set to -1 on entry so the error path can tell whether
+ *          a port was ever allocated (callers pass it uninitialised).
+ *
+ * Returns the new mapping, or NULL on any failure; on failure every
+ * resource acquired so far has been released.  Allocates with
+ * GFP_KERNEL, so it must not be called from atomic context.
+ */
+static struct sock_mapping *create_active(int *evtchn)
+{
+	struct sock_mapping *map;
+	void *bytes;
+	int ret, irq, i;
+	int granted = 0;		/* data pages granted so far */
+	bool ring_granted = false;	/* ring page granted */
+
+	*evtchn = -1;
+
+	map = kzalloc(sizeof(*map), GFP_KERNEL);
+	if (map == NULL)
+		return NULL;
+
+	init_waitqueue_head(&map->active.inflight_conn_req);
+
+	/* __GFP_ZERO already clears the page; no separate memset needed. */
+	map->active.ring = (struct pvcalls_data_intf *)
+		__get_free_page(GFP_KERNEL | __GFP_ZERO);
+	if (map->active.ring == NULL)
+		goto out_error;
+	map->active.ring->ring_order = RING_ORDER;
+
+	bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, RING_ORDER);
+	if (bytes == NULL)
+		goto out_error;
+	/* Record the data pages right away so out_error can free them. */
+	map->active.data.in = bytes;
+	map->active.data.out = bytes + XEN_FLEX_RING_SIZE(RING_ORDER);
+
+	for (i = 0; i < (1 << RING_ORDER); i++) {
+		ret = gnttab_grant_foreign_access(
+			pvcalls_front_dev->otherend_id,
+			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);
+		/* A negative value is an error, not a grant reference. */
+		if (ret < 0)
+			goto out_error;
+		map->active.ring->ref[i] = ret;
+		granted++;
+	}
+
+	ret = gnttab_grant_foreign_access(
+		pvcalls_front_dev->otherend_id,
+		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);
+	if (ret < 0)
+		goto out_error;
+	map->active.ref = ret;
+	ring_granted = true;
+
+	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
+	if (ret)
+		goto out_error;
+	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
+					0, "pvcalls-frontend", map);
+	if (irq < 0)
+		goto out_error;
+
+	map->active.irq = irq;
+	map->active_socket = true;
+	mutex_init(&map->active.in_mutex);
+	mutex_init(&map->active.out_mutex);
+
+	return map;
+
+out_error:
+	/*
+	 * No path reaches here with the irq bound, so only the raw event
+	 * channel (if any) has to be released.
+	 */
+	if (*evtchn >= 0)
+		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
+	/* Revoke grants before the pages go back to the allocator. */
+	if (ring_granted)
+		gnttab_end_foreign_access(map->active.ref, 0, 0);
+	for (i = 0; i < granted; i++)
+		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
+	/* Page-allocator memory must not be handed to kfree(). */
+	free_pages((unsigned long)map->active.data.in, RING_ORDER);
+	free_page((unsigned long)map->active.ring);
+	kfree(map);
+	return NULL;
+}
+
+/*
+ * Release a mapping obtained from create_active() that was never
+ * announced to the backend and never added to any list.
+ *
+ * NOTE(review): assumes every ring->ref[] entry and active.ref hold
+ * valid grant references, i.e. that create_active() succeeded fully.
+ */
+static void pvcalls_front_discard_active(struct sock_mapping *map)
+{
+	int i;
+
+	/* Unbinding the irq also closes the event channel bound to it. */
+	unbind_from_irqhandler(map->active.irq, map);
+	for (i = 0; i < (1 << RING_ORDER); i++)
+		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
+	gnttab_end_foreign_access(map->active.ref, 0, 0);
+	free_pages((unsigned long)map->active.data.in, RING_ORDER);
+	free_page((unsigned long)map->active.ring);
+	kfree(map);
+}
+
+/*
+ * Send a PVCALLS_CONNECT request for @sock to the backend and wait for
+ * the response.
+ *
+ * @sock:     socket to connect; its pointer value is used as the id
+ *            passed to the backend.
+ * @addr:     destination address; only AF_INET is accepted.
+ * @addr_len: length of @addr, forwarded to the backend.
+ * @flags:    connect flags, forwarded to the backend.
+ *
+ * Returns -ENETUNREACH when no pvcalls device is present, -ENOTSUPP for
+ * anything other than an AF_INET stream socket, -ENOMEM when the
+ * per-socket state cannot be set up, -EAGAIN when no request slot is
+ * free, and otherwise the return value reported by the backend.
+ * Sleeps (uninterruptibly) until the backend responds.
+ */
+int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
+				int addr_len, int flags)
+{
+	struct pvcalls_bedata *bedata;
+	struct sock_mapping *map;
+	struct xen_pvcalls_request *req;
+	int notify, req_id, ret, evtchn;
+
+	if (!pvcalls_front_dev)
+		return -ENETUNREACH;
+	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
+		return -ENOTSUPP;
+
+	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
+
+	/*
+	 * create_active() allocates with GFP_KERNEL and binds an irq, both
+	 * of which may sleep, so it has to run before the spinlock is taken.
+	 */
+	map = create_active(&evtchn);
+	if (!map)
+		return -ENOMEM;
+
+	spin_lock(&bedata->pvcallss_lock);
+	req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
+	if (RING_FULL(&bedata->ring) ||
+	    READ_ONCE(bedata->rsp[req_id].req_id) != PVCALLS_INVALID_ID) {
+		spin_unlock(&bedata->pvcallss_lock);
+		/* The backend never saw this mapping; undo it completely. */
+		pvcalls_front_discard_active(map);
+		return -EAGAIN;
+	}
+
+	req = RING_GET_REQUEST(&bedata->ring, req_id);
+	req->req_id = req_id;
+	req->cmd = PVCALLS_CONNECT;
+	req->u.connect.id = (uint64_t)sock;
+	memcpy(req->u.connect.addr, addr, sizeof(*addr));
+	req->u.connect.len = addr_len;
+	req->u.connect.flags = flags;
+	req->u.connect.ref = map->active.ref;
+	req->u.connect.evtchn = evtchn;
+
+	list_add_tail(&map->list, &bedata->socket_mappings);
+	map->sock = sock;
+	/* Stash the mapping in the sock so it can be found from the socket. */
+	WRITE_ONCE(sock->sk->sk_send_head, (void *)map);
+
+	bedata->ring.req_prod_pvt++;
+	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
+	spin_unlock(&bedata->pvcallss_lock);
+
+	if (notify)
+		notify_remote_via_irq(bedata->irq);
+
+	wait_event(bedata->inflight_req,
+		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
+
+	ret = bedata->rsp[req_id].ret;
+	/* read ret, then set this rsp slot to be reused */
+	smp_mb();
+	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);
+	return ret;
+}
+
static const struct xenbus_device_id pvcalls_front_ids[] = {
{ "pvcalls" },
{ "" }
@@ -150,7 +305,7 @@ static int pvcalls_front_remove(struct xenbus_device *dev)
static int pvcalls_front_probe(struct xenbus_device *dev,
const struct xenbus_device_id *id)
{
- int ret = -EFAULT, evtchn, ref = -1, i;
+ int ret = -ENOMEM, evtchn, ref = -1, i;
unsigned int max_page_order, function_calls, len;
char *versions;
grant_ref_t gref_head = 0;
@@ -171,15 +326,13 @@ static int pvcalls_front_probe(struct xenbus_device *dev,
return -EINVAL;
}
kfree(versions);
- ret = xenbus_scanf(XBT_NIL, dev->otherend,
- "max-page-order", "%u", &max_page_order);
- if (ret <= 0)
- return -ENODEV;
+ max_page_order = xenbus_read_unsigned(dev->otherend,
+ "max-page-order", 0);
if (max_page_order < RING_ORDER)
return -ENODEV;
- ret = xenbus_scanf(XBT_NIL, dev->otherend,
- "function-calls", "%u", &function_calls);
- if (ret <= 0 || function_calls != 1)
+ function_calls = xenbus_read_unsigned(dev->otherend,
+ "function-calls", 0);
+ if (function_calls != 1)
return -ENODEV;
pr_info("%s max-page-order is %u\n", __func__, max_page_order);
@@ -187,6 +340,8 @@ static int pvcalls_front_probe(struct xenbus_device *dev,
if (!bedata)
return -ENOMEM;
+ dev_set_drvdata(&dev->dev, bedata);
+ pvcalls_front_dev = dev;
init_waitqueue_head(&bedata->inflight_req);
for (i = 0; i < PVCALLS_NR_REQ_PER_RING; i++)
bedata->rsp[i].req_id = PVCALLS_INVALID_ID;
@@ -214,8 +369,10 @@ static int pvcalls_front_probe(struct xenbus_device *dev,
if (ret < 0)
goto error;
bedata->ref = ref = gnttab_claim_grant_reference(&gref_head);
- if (ref < 0)
+ if (ref < 0) {
+ ret = ref;
goto error;
+ }
gnttab_grant_foreign_access_ref(ref, dev->otherend_id,
virt_to_gfn((void *)sring), 0);
@@ -246,8 +403,6 @@ static int pvcalls_front_probe(struct xenbus_device *dev,
INIT_LIST_HEAD(&bedata->socket_mappings);
INIT_LIST_HEAD(&bedata->socketpass_mappings);
spin_lock_init(&bedata->pvcallss_lock);
- dev_set_drvdata(&dev->dev, bedata);
- pvcalls_front_dev = dev;
xenbus_switch_state(dev, XenbusStateInitialised);