#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <infiniband/verbs.h>
#include <rdma/rdma_cma.h>
#include <mpi.h>
#include "iflist.h"


#define ERROR(s) printf("ERROR %s: %s\n", (s), strerror(errno))


struct rdma_event_channel* rdma_chan;
struct rdma_cm_id* rdma_id;

int mpi_rank;
int mpi_size;


/*
 * Create the RDMA CM event channel and a UDP-port-space cm_id bound to the
 * first entry of if_addrs[] that binds successfully AND has an associated
 * verbs device.
 *
 * On return rdma_chan and rdma_id are valid and rdma_id->verbs is non-NULL.
 * Any failure (including "no usable interface") terminates the process.
 * Rank 0 also prints the device's multicast limits.
 */
static void rdma_init(void)
{
    int i;
    int found = 0;

    rdma_chan = rdma_create_event_channel();
    if(!rdma_chan) {
        ERROR("rdma_create_event_channel()");
        exit(-1);
    }

    /* Just use the first interface that works */
    for(i = 0; i < if_num_addrs; i++) {
        if(rdma_create_id(rdma_chan, &rdma_id, NULL, RDMA_PS_UDP)) {
            ERROR("rdma_create_id()");
            exit(-1);
        }

        if(rdma_bind_addr(rdma_id, (struct sockaddr*)&if_addrs[i])) {
            /* rdma_destroy_id() may overwrite errno; keep the bind error so
             * the ENOENT test and the message below see the real cause. */
            int saved_errno = errno;

            rdma_destroy_id(rdma_id);
            rdma_id = NULL;
            errno = saved_errno;

            if(errno == ENOENT) {
                continue;
            }

            ERROR("rdma_bind_addr()");
            exit(-1);
        }

        /* NOTE(review): rdma_bind_addr() has been observed to succeed while
         * leaving verbs NULL; cause unknown. Such an id is unusable. */
        if(rdma_id->verbs) {
            found = 1;
            break;
        }

        /* Release the unusable id instead of leaking it when the next
         * rdma_create_id() overwrites rdma_id. */
        rdma_destroy_id(rdma_id);
        rdma_id = NULL;
    }

    /* Without this check the code below would dereference a NULL or
     * destroyed rdma_id. */
    if(!found) {
        printf("ERROR no interface could be bound to an RDMA device\n");
        exit(-1);
    }

    if(mpi_rank == 0) {
        struct ibv_device_attr dev_attr = {{0}, 0};

        if(ibv_query_device(rdma_id->verbs, &dev_attr)) {
            printf("ERROR could not query device\n");
            exit(-4);
        }

        printf("max_mcast_grp %d\nmax_mcast_qp_attach %d\nmax_total_mcast_qp_attach %d\n",
                dev_attr.max_mcast_grp, dev_attr.max_mcast_qp_attach,
                dev_attr.max_total_mcast_qp_attach);
    }
}


/*
 * Non-blocking check for a single RDMA CM event on rdma_chan.
 *
 * Returns:
 *   0  no event ready (or select() was interrupted by a signal)
 *   1  an RDMA_CM_EVENT_MULTICAST_JOIN event was consumed
 *  -1  some other (unhandled) event was consumed
 *
 * On rank 0 a join event also performs the root side of the MPI_Bcast +
 * MPI_Barrier that hands the joined group's GID to the other ranks.
 * Address/route/multicast error events deliberately hang the process.
 */
static int poll_rdma(void)
{
    struct rdma_cm_event* event;
    struct timeval tv = {0};
    fd_set fds;
    int ready;
    int ret = -1;

    FD_ZERO(&fds);
    FD_SET(rdma_chan->fd, &fds);

    /* Zero timeout: pure poll. On failure select() leaves the fd set
     * unmodified, so FD_ISSET alone would wrongly report "ready" and the
     * rdma_get_cm_event() below would block; check the return value. */
    ready = select(rdma_chan->fd + 1, &fds, NULL, NULL, &tv);
    if(ready < 0) {
        if(errno == EINTR) {
            return 0;
        }
        printf("ERROR %s: %s\n", "select()", strerror(errno));
        exit(-4);
    }
    if(ready == 0 || !FD_ISSET(rdma_chan->fd, &fds)) {
        return 0;
    }

    if(rdma_get_cm_event(rdma_chan, &event)) {
        printf("ERROR could not get event\n");
        exit(-4);
    }

    switch(event->event) {
    case RDMA_CM_EVENT_MULTICAST_JOIN:
        if(mpi_rank == 0) {
            struct rdma_ud_param* param = &event->param.ud;
            char buf[48] = {0};

            inet_ntop(AF_INET6, param->ah_attr.grh.dgid.raw, buf, 48);
            printf("%d joined group %s\n", mpi_rank, buf);

            /* Bcast the addr to everyone else */
            MPI_Bcast(param->ah_attr.grh.dgid.raw,
                    sizeof(struct in6_addr), MPI_BYTE, 0, MPI_COMM_WORLD);
            MPI_Barrier(MPI_COMM_WORLD);
        }

        ret = 1;
        break;
    case RDMA_CM_EVENT_ADDR_ERROR:
    case RDMA_CM_EVENT_ROUTE_ERROR:
    case RDMA_CM_EVENT_MULTICAST_ERROR:
        /* errno says nothing about an asynchronous event; the error is in
         * event->status, which rdma_cm sets to a negative errno or to a
         * transport-specific value. */
        printf("ERROR event %d, status %d %s, forcing job to hang\n",
                event->event, event->status,
                event->status < 0 ? strerror(-event->status) : "(non-errno status)");
        fflush(stdout);
        while(1) { sleep(300); }
        break;
    default:
        printf("Unhandled event %d\n", event->event);
        break;
    }

    rdma_ack_cm_event(event);
    return ret;
}


int join_multicast(void)
{
    struct sockaddr_in6 addr = {0};
    int rc;

    addr.sin6_family = AF_INET6;
    addr.sin6_flowinfo = 0;
    addr.sin6_port = 0;

    if(mpi_rank == 0) {
        /* Root joins the group, then bcasts the addr when joined */
        addr.sin6_addr = in6addr_any;
    } else {
        char buf[48];

        /* Everyone else waits on the bcast, then joins */
        MPI_Bcast(&addr.sin6_addr,
                sizeof(struct in6_addr), MPI_UNSIGNED_CHAR, 0, MPI_COMM_WORLD);
        MPI_Barrier(MPI_COMM_WORLD);

        inet_ntop(AF_INET6, &addr.sin6_addr, buf, 48);
        /*printf("%d got addr %s\n", mpi_rank, buf);*/
    }

    rc = rdma_join_multicast(rdma_id, (struct sockaddr*)&addr, NULL);
    if(rc) {
        printf("%d ERROR rdma_join_multicast(): %d %s\nforcing job to hang!\n",
                mpi_rank, errno, strerror(errno));
        fflush(stdout);
        while(1) { sleep(300); }
        return -1;
    }

    /* poll_rdma() does the root side of the MPI bcast */
    while(poll_rdma() <= 0);

    return 0;
}


/*
 * Initialise MPI and RDMA, then call join_multicast() in an endless loop,
 * with rank 0 printing the iteration number before each join.
 *
 * The loop only ends if join_multicast() reports failure; in that case the
 * interface list and MPI are torn down and EXIT_FAILURE is returned.
 */
int main(int argc, char** argv)
{
    int i;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);

    load_ifs();
    rdma_init();

    for(i = 0; ; i++) {
        if(mpi_rank == 0) {
            printf("group %d\n", i);
            fflush(stdout);
        }

        /* join_multicast() hangs on failure today, but if it ever returns
         * an error, stop instead of silently looping, so cleanup runs. */
        if(join_multicast() != 0) {
            break;
        }
    }

    unload_ifs();
    MPI_Finalize();
    return EXIT_FAILURE;
}

