#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <time.h>
#include <limits.h>

//#define NULL ((void *) 0)
/* Short aliases for the fixed-width 64-bit integer types from <stdint.h>. */
typedef uint64_t uint64;
typedef int64_t int64;

/* Width in bits of one bitmapword; must match the typedefs just below. */
#define BITS_PER_BITMAPWORD 64
typedef uint64 bitmapword;      /* must be an unsigned type */
typedef int64 signedbitmapword; /* must be the matching signed type */

/*
 * Split a bit number into the index of the word that holds it (WORDNUM) and
 * the bit's position inside that word (BITNUM).
 *
 * The arithmetic is done in the type of x.  For a negative signed x, C
 * division truncates toward zero and % yields a non-positive result, so
 * these macros are only meaningful for x >= 0.
 */
#define WORDNUM(x)  ((x) / BITS_PER_BITMAPWORD)
#define BITNUM(x)   ((x) % BITS_PER_BITMAPWORD)

/*
 * Bitmapset: a set of bit numbers stored as an array of bitmapwords.
 *
 * The scan functions in this file look for bit x in words[WORDNUM(x)] at
 * position BITNUM(x).  "words" is a flexible array member, so an instance
 * must be allocated with room for nwords trailing elements (main does this
 * with sizeof(Bitmapset) + nwords * sizeof(bitmapword)).
 */
typedef struct Bitmapset
{
    int         nwords;         /* number of words in array */
    bitmapword  words[];    /* really [nwords] */
} Bitmapset;

/*
 * bmw_rightmost_one_pos
 *      Return the zero-based position of the least significant 1-bit of
 *      "word", counted from the least significant end.
 *
 * This is a thin wrapper around the compiler's count-trailing-zeros builtin.
 * The builtin's result is undefined for an input of zero, so callers must
 * only pass a nonzero word (every call site in this file checks first).
 */
static inline int
bmw_rightmost_one_pos(uint64 word)
{
    const int   trailing_zeros = __builtin_ctzll(word);

    return trailing_zeros;
}

/*
 * 1. Original version
 *
 * Return the smallest member of "a" that is strictly greater than prevbit,
 * or -2 if there is no such member.  A NULL set also yields -2.  Passing
 * prevbit = -1 starts the search at bit 0.
 *
 * This is the baseline the other variants are timed against: word index and
 * word count are plain ints, and the masking of the first word is folded
 * into the loop.
 *
 * NOTE(review): prevbit++ is signed arithmetic, so prevbit == INT_MAX
 * overflows (undefined behaviour).  For prevbit < -1 the incremented value
 * is still negative, BITNUM() is then negative, and shifting by a negative
 * count is undefined as well.  Presumably callers only pass -1 or a value
 * previously returned by this function -- confirm before relying on it.
 */
int
bms_next_member(const Bitmapset *a, int prevbit)
{
    int         nwords;
    bitmapword  mask;

    //Assert(bms_is_valid_set(a));

    if (a == NULL)
        return -2;
    nwords = a->nwords;
    /* from here on, prevbit is the first bit number that may be returned */
    prevbit++;
    /* all-ones shifted left: keeps bit BITNUM(prevbit) and everything above */
    mask = (~(bitmapword) 0) << BITNUM(prevbit);
    for (int wordnum = WORDNUM(prevbit); wordnum < nwords; wordnum++)
    {
        bitmapword  w = a->words[wordnum];

        /* ignore bits below the (already incremented) prevbit */
        w &= mask;

        if (w != 0)
        {
            int         result;

            /* bit number = first bit of this word + offset of lowest 1-bit */
            result = wordnum * BITS_PER_BITMAPWORD;
            result += bmw_rightmost_one_pos(w);
            return result;
        }

        /* in subsequent words, consider all bits */
        mask = (~(bitmapword) 0);
    }
    /* ran off the end of the word array without finding a set bit */
    return -2;
}

// 2. Fast version (size_t usage)
int
bms_next_member_fast(const Bitmapset *a, int prevbit)
{
    uint64      currbit;
    size_t      nwords;
    bitmapword  mask;

    if (a == NULL)
        return -2;
    nwords = (size_t) a->nwords;
    currbit = (uint64) prevbit + 1;
    mask = (~(bitmapword) 0) << BITNUM(currbit);
    for (size_t wordnum = WORDNUM(currbit); wordnum < nwords; wordnum++)
    {
        bitmapword  w = a->words[wordnum];

        /* ignore bits before currbit */
        w &= mask;

        if (w != 0)
        {
            int         result;

            result = (int) wordnum * BITS_PER_BITMAPWORD;
            result += bmw_rightmost_one_pos(w);
            return result;
        }

        /* in subsequent words, consider all bits */
        mask = (~(bitmapword) 0);
    }
    return -2;
}

/*
 * 3. Original version + INT_MAX check + size_t word index
 *
 * Same contract as bms_next_member: return the smallest member of "a" that
 * is greater than prevbit, or -2 when there is none or "a" is NULL.
 *
 * Differences from the original version:
 *   - prevbit == INT_MAX returns -2 up front.  The result is an int, so no
 *     larger member could be reported anyway, and the guard keeps the signed
 *     increment below from overflowing.  The bound is INT_MAX (the limit of
 *     prevbit's own type) rather than INT32_MAX, which only coincides with
 *     it where int is 32 bits wide.
 *   - the word index and word count are size_t instead of int.
 *
 * NOTE(review): as in the original version, prevbit < -1 still leads to a
 * negative shift count; callers are presumably expected to pass -1 or a
 * previously returned member -- confirm.
 */
int
bms_next_member_2(const Bitmapset *a, int prevbit)
{
    size_t      nwords;
    bitmapword  mask;

    if (a == NULL || prevbit == INT_MAX)
        return -2;
    nwords = (size_t) a->nwords;
    /* safe: prevbit < INT_MAX here; it now names the first candidate bit */
    prevbit++;
    mask = (~(bitmapword) 0) << BITNUM(prevbit);
    for (size_t wordnum = WORDNUM(prevbit); wordnum < nwords; wordnum++)
    {
        bitmapword  w = a->words[wordnum];

        /* ignore bits below the (already incremented) prevbit */
        w &= mask;

        if (w != 0)
        {
            int         result;

            /* first bit of this word plus the offset of its lowest 1-bit */
            result = (int) wordnum * BITS_PER_BITMAPWORD;
            result += bmw_rightmost_one_pos(w);
            return result;
        }

        /* in subsequent words, consider all bits */
        mask = (~(bitmapword) 0);
    }
    return -2;
}

/*
 * 4. Pull up first iteration
 *
 * Same contract as bms_next_member: return the smallest member of "a" that
 * is greater than prevbit, or -2 when there is none or "a" is NULL.
 *
 * Only the first word examined needs its low bits masked off, so that word
 * is handled before the loop and the remaining words are scanned with a
 * plain pointer walk that carries no mask.
 *
 * The word index is kept in size_t.  The starting bit number is computed in
 * unsigned 64-bit arithmetic, so for prevbit < -1 it wraps to a huge value;
 * narrowing its word index to int could produce a negative index that slips
 * past the bounds check and reads before the start of a->words.  As size_t
 * the index stays huge and the bounds check rejects it.
 */
int
bms_next_member_pullup(const Bitmapset *a, int prevbit)
{
    if (a == NULL || prevbit == INT_MAX)
        return -2;

    uint64      currbit = (uint64) prevbit + 1;
    size_t      wordnum = WORDNUM(currbit);
    size_t      nwords = (size_t) a->nwords;

    if (wordnum >= nwords)
        return -2;

    /* Handle first word with mask: drop the bits below currbit */
    const bitmapword *p = &a->words[wordnum];
    bitmapword  w = (*p) & ((~(bitmapword) 0) << BITNUM(currbit));

    if (w != 0)
        return (int) wordnum * BITS_PER_BITMAPWORD + bmw_rightmost_one_pos(w);

    /* Scan the remaining words in full; "end" is one past the last word */
    const bitmapword *end = &a->words[nwords];

    for (p++; p < end; p++)
    {
        if (*p != 0)
        {
            /* recover the word index from the pointer's offset */
            wordnum = (size_t) (p - a->words);
            return (int) wordnum * BITS_PER_BITMAPWORD + bmw_rightmost_one_pos(*p);
        }
    }

    return -2;
}


/*
 * Return the current CLOCK_MONOTONIC reading as seconds in a double
 * (whole seconds plus the nanosecond part scaled by 1e-9).
 *
 * Only differences between two return values are meaningful.  If the clock
 * cannot be read, the error is reported on stderr and 0.0 is returned
 * instead of reading an uninitialised timespec.
 */
double get_time(void) {
    struct timespec ts;

    if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) {
        perror("clock_gettime");
        return 0.0;
    }
    return (double) ts.tv_sec + (double) ts.tv_nsec * 1e-9;
}

int main() {
    int words_to_alloc = 20000; // Large set to bypass CPU cache slightly
    Bitmapset *bms = malloc(sizeof(Bitmapset) + words_to_alloc * sizeof(bitmapword));
    bms->nwords = words_to_alloc;
    memset(bms->words, 0, words_to_alloc * sizeof(bitmapword));

    /* Set a bit far into the set to force a long scan */
    int target_bit = (words_to_alloc - 1) * 64 + 10;
    bms->words[words_to_alloc - 1] |= (1ULL << 10);

    int iterations = 100000;
    volatile int sink;

    printf("Benchmarking %d iterations...\n\n", iterations);

    // Test Original
    double start = get_time();
    for (int i = 0; i < iterations; i++) sink = bms_next_member(bms, 0);
    printf("Original:  %.5f seconds\n", get_time() - start);

    // Test Fast
    start = get_time();
    for (int i = 0; i < iterations; i++) sink = bms_next_member_fast(bms, 0);
    printf("Fast:      %.5f seconds\n", get_time() - start);

    // Test Original2
    start = get_time();
    for (int i = 0; i < iterations; i++) sink = bms_next_member_2(bms, 0);
    printf("Original2:      %.5f seconds\n", get_time() - start);

    // Pull up first iteration
    start = get_time();
    for (int i = 0; i < iterations; i++) sink = bms_next_member_pullup(bms, 0);
    printf("PullUp: %.5f seconds\n", get_time() - start);

    free(bms);
    return 0;
}