Hi,

There's a memory leak in the DTLS code when renegotiating an already established
session. Valgrind output:

==5475== HEAP SUMMARY:
==5475==     in use at exit: 2,285 bytes in 17 blocks
==5475==   total heap usage: 7,973 allocs, 7,956 frees, 789,213 bytes
allocated
==5475==
==5475== 432 (168 direct, 264 indirect) bytes in 1 blocks are definitely
lost in loss record 15 of 17
==5475==    at 0x4C29F90: malloc (in
/usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==5475==    by 0x4518EC: default_malloc_ex (mem.c:79)
==5475==    by 0x451F7B: CRYPTO_malloc (mem.c:308)
==5475==    by 0x47C948: EVP_CIPHER_CTX_new (evp_enc.c:90)
==5475==    by 0x446C80: tls1_change_cipher_state (t1_enc.c:419)
==5475==    by 0x403E6D: dtls1_accept (d1_srvr.c:738)
==5475==    by 0x415315: SSL_do_handshake (ssl_lib.c:2605)
==5475==    by 0x402226: handshakeIteration (dtlstest2.c:146)
==5475==    by 0x402816: main (dtlstest2.c:278)
==5475==
==5475== 861 (48 direct, 813 indirect) bytes in 1 blocks are definitely
lost in loss record 17 of 17
==5475==    at 0x4C29F90: malloc (in
/usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
==5475==    by 0x4518EC: default_malloc_ex (mem.c:79)
==5475==    by 0x451F7B: CRYPTO_malloc (mem.c:308)
==5475==    by 0x47BEA9: EVP_MD_CTX_create (digest.c:131)
==5475==    by 0x446CD4: tls1_change_cipher_state (t1_enc.c:424)
==5475==    by 0x403E6D: dtls1_accept (d1_srvr.c:738)
==5475==    by 0x415315: SSL_do_handshake (ssl_lib.c:2605)
==5475==    by 0x402226: handshakeIteration (dtlstest2.c:146)
==5475==    by 0x402816: main (dtlstest2.c:278)
==5475==
==5475== LEAK SUMMARY:
==5475==    definitely lost: 216 bytes in 2 blocks
==5475==    indirectly lost: 1,077 bytes in 9 blocks
==5475==      possibly lost: 0 bytes in 0 blocks
==5475==    still reachable: 992 bytes in 6 blocks
==5475==         suppressed: 0 bytes in 0 blocks


The test application is attached. Just run it under valgrind with
--leak-check=full.

The OpenSSL version is the latest 1.0.1 revision from the official repository.
Tested under Linux x86_64.

---
Thanks,
Dmitry Sobinov

#include <assert.h>
#include <limits.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/x509.h>


/**
 * A generated key pair together with the certificate built from it
 * (see generateIdentity()). Both members are owned by whoever holds
 * the struct and are released with EVP_PKEY_free() / X509_free().
 */
typedef struct DtlsIdentity_
{
    EVP_PKEY *key;          /* RSA key pair */
    X509 *certificate;      /* certificate carrying key's public half */
} DtlsIdentity;

/*
 * Helper functions (defined at the bottom of the file).
 *
 * Empty parameter lists are written `(void)` so these are real prototypes:
 * in C, `()` declares a function with unspecified parameters and turns off
 * argument checking at every call site.
 */
unsigned long idFunction(void);
void opensslLockingFunc(int mode, int n, const char *file, int line);
void opensslInit(void);
void opensslCleanup(void);
EVP_PKEY *generateRsaKeyPair(void);
X509 *generateCertificate(EVP_PKEY *pkey, const char *commonName);
DtlsIdentity generateIdentity(void);


/** Which side of the DTLS handshake a PeerContext plays. */
enum DtlsRole
{
    DTLS_CLIENT,    /* 0: SSL_CTX built from DTLSv1_client_method() */
    DTLS_SERVER     /* 1: SSL_CTX built from DTLSv1_server_method() */
};

/**
 * State of one endpoint in the in-memory client/server handshake driver.
 */
typedef struct PeerContext_
{
    SSL *ssl;                   /* connection; owns inBio/outBio after SSL_set_bio() */
    SSL_CTX *ctx;               /* context ssl was created from */
    BIO *inBio;                 /* memory BIO fed with the other peer's datagrams */
    BIO *outBio;                /* memory BIO collecting this peer's outgoing datagrams */
    X509 *certificate;          /* not assigned by initContext() in the visible code */
    EVP_PKEY *key;              /* not assigned by initContext() in the visible code */
    enum DtlsRole role;         /* must be set before initContext() */
    char label[20];             /* "[C] " / "[S] " prefix for log lines */

    bool handshakeCompleted;    /* initial handshake has finished */
    bool activeRenegotiation;   /* a renegotiation is currently in progress */

    int negotiationsCount;      /* completed handshakes: 1 = initial, 2 = first renegotiation */
} PeerContext;

void sslInfoCallbackInternal(PeerContext *ctx, const SSL *s, int where, int ret);

/**
 * Info callback registered via SSL_CTX_set_info_callback().
 * Looks up the PeerContext stored as the SSL's app data and forwards the
 * event to sslInfoCallbackInternal().
 */
void sslInfoCallback(const SSL* s, int where, int ret)
{
    PeerContext *peer = (PeerContext *)SSL_get_app_data(s);
    sslInfoCallbackInternal(peer, s, where, ret);
}
/**
 * Peer certificate verification callback that accepts every certificate,
 * whatever OpenSSL's own verdict (`ok`) was. Kept this way for simplicity
 * in this test program.
 *
 * @return always 1 (accept).
 */
int sslVerifyCallback(int ok, X509_STORE_CTX* store)
{
    (void)ok;
    (void)store;
    return 1;
}

/**
 * Set up a peer for the role already stored in ctx->role.
 *
 * Generates a fresh RSA key and self-signed certificate, then creates a DTLSv1
 * SSL_CTX (client method for DTLS_CLIENT, server method otherwise) and an SSL
 * object. It also attaches two memory BIOs so datagrams can be passed between
 * the peers by hand. Setup failures are treated as fatal via assert().
 *
 * @param ctx  peer to initialize; only ctx->role is read on entry.
 */
void initContext(PeerContext* ctx)
{
    // set labels to distinguish client/server for logging:
    snprintf(ctx->label, sizeof ctx->label, "%s",
        (ctx->role == DTLS_CLIENT) ? "[C] " : "[S] ");

    ctx->handshakeCompleted = false;
    ctx->activeRenegotiation = false;
    ctx->negotiationsCount = 0;
    // The identity is handed to the SSL_CTX below; these members are not
    // used by this driver, but should not be left indeterminate.
    ctx->certificate = NULL;
    ctx->key = NULL;

    // generate new certificate and private key:
    DtlsIdentity identity = generateIdentity();
    assert(identity.key != NULL && identity.certificate != NULL);

    ctx->ctx = (ctx->role == DTLS_CLIENT) ?
        SSL_CTX_new(DTLSv1_client_method()) :
        SSL_CTX_new(DTLSv1_server_method());
    assert(ctx->ctx);

    // Calls with side effects stay outside assert() so they still run when
    // NDEBUG is defined; only the result checks are compiled out.
    int ok = SSL_CTX_use_certificate(ctx->ctx, identity.certificate);
    assert(ok == 1);
    ok = SSL_CTX_use_PrivateKey(ctx->ctx, identity.key);
    assert(ok == 1);
    ok = SSL_CTX_check_private_key(ctx->ctx);
    assert(ok == 1);
    // The SSL_CTX keeps its own references; drop ours.
    EVP_PKEY_free(identity.key);
    X509_free(identity.certificate);

    SSL_CTX_set_info_callback(ctx->ctx, &sslInfoCallback);

    SSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
        &sslVerifyCallback);
    SSL_CTX_set_verify_depth(ctx->ctx, 1);
    ok = SSL_CTX_set_cipher_list(ctx->ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
    assert(ok == 1);
    (void)ok; // only inspected by assert()

    SSL_CTX_set_read_ahead(ctx->ctx, 1);

    // SSL_OP_NO_COMPRESSION: "bad decompression" error fix for some linuxes;
    // SSL_OP_NO_TICKET: disable tickets for simplicity.
    SSL_CTX_set_options(ctx->ctx, SSL_OP_NO_COMPRESSION | SSL_OP_NO_TICKET);

    ctx->ssl = SSL_new(ctx->ctx);
    assert(ctx->ssl);

    ctx->inBio = BIO_new(BIO_s_mem());
    ctx->outBio = BIO_new(BIO_s_mem());
    assert(ctx->inBio != NULL && ctx->outBio != NULL);

    // lets sslInfoCallback() find this PeerContext again
    SSL_set_app_data(ctx->ssl, ctx);

    if (ctx->role == DTLS_CLIENT)
        SSL_set_connect_state(ctx->ssl);
    else
        SSL_set_accept_state(ctx->ssl);

    SSL_set_bio(ctx->ssl, ctx->inBio, ctx->outBio);  //< the SSL object owns the bio now
}

/**
 * Start a locally initiated renegotiation on an established connection.
 *
 * Marks the peer as renegotiating and empties both memory BIOs. It then asks
 * OpenSSL to renegotiate; the handshake messages themselves are produced by
 * the next SSL_do_handshake() in handshakeIteration().
 *
 * @param ctx  peer whose initial handshake has completed and which is not
 *             already renegotiating (both checked with assert()).
 */
void renegotiate(PeerContext* ctx)
{
    printf("%s <<<<Renegotiation requested>>>>\n", ctx->label);
    assert(ctx->handshakeCompleted);
    assert(!ctx->activeRenegotiation);

    ctx->activeRenegotiation = true;
    // drop stale datagrams so the new handshake starts from empty buffers
    (void)BIO_reset(ctx->inBio);
    (void)BIO_reset(ctx->outBio);

    // SSL_renegotiate() returns 1 when the renegotiation was scheduled.
    int scheduled = SSL_renegotiate(ctx->ssl);
    assert(scheduled == 1);
    (void)scheduled; // only inspected by assert()
}

/**
 * Drive one step of the handshake, or do one read once it has completed.
 *
 * Calls SSL_do_handshake() while the initial handshake or a renegotiation is
 * in progress. Otherwise it calls SSL_read(), so that a renegotiation started
 * by the remote side gets processed. Updates ctx->handshakeCompleted,
 * ctx->activeRenegotiation and ctx->negotiationsCount accordingly.
 *
 * @param ctx         peer to advance.
 * @param dataToSend  out: pointer into ctx->outBio's internal buffer holding
 *                    the bytes this peer wants to send.
 * @param len         out: number of bytes at *dataToSend (may be 0).
 * @param timeoutMs   out: DTLS retransmission timeout in milliseconds; written
 *                    only when the function returns true.
 * @return true if the peer wants to read and a DTLS timer is running.
 */
bool handshakeIteration(PeerContext* ctx, uint8_t** dataToSend, size_t* len, int* timeoutMs)
{
    bool wantRead = false;

    // we don't actually read data, but need this for SSL_read:
    uint8_t buf[4096];

    // SSL_get_error() also consults the thread's error queue, so make sure
    // nothing left over from an earlier call can be misattributed.
    ERR_clear_error();

    // SSL_read after initial negotiation, SSL_do_handshake on client side
    // when renegotiation requested
    int res = (ctx->handshakeCompleted && !ctx->activeRenegotiation) ?
        SSL_read(ctx->ssl, buf, (int)sizeof(buf)) : SSL_do_handshake(ctx->ssl);

    // classify the result right away, before any other OpenSSL call
    int err = SSL_get_error(ctx->ssl, res);

    // get pointer to data written by handshake
    *len = (size_t)BIO_get_mem_data(ctx->outBio, dataToSend);

    struct timeval timeout;

    // check if remote side requested renegotiation
    if (!ctx->activeRenegotiation && ctx->handshakeCompleted && SSL_renegotiate_pending(ctx->ssl) == 1)
    {
        printf("%s Remote renegotiation detected\n", ctx->label);
        ctx->activeRenegotiation = true;
    }

    // check if renegotiation finished
    bool renegotiationFinished = ctx->activeRenegotiation && SSL_renegotiate_pending(ctx->ssl) == 0;

    // handle handshake errors
    switch (err)
    {
    case SSL_ERROR_NONE:
        if (!ctx->handshakeCompleted || renegotiationFinished)
        {
            ctx->handshakeCompleted = true;
            ctx->activeRenegotiation = false;
            ctx->negotiationsCount++;
        }
        break;

    case SSL_ERROR_WANT_READ:
        if (renegotiationFinished)
        {
            ctx->activeRenegotiation = false;
            ctx->negotiationsCount++;
        }
        else if (DTLSv1_get_timeout(ctx->ssl, &timeout))
        {
            *timeoutMs = (int)(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
            wantRead = true;
            printf("%s WANT_READ with timeout %d\n", ctx->label, *timeoutMs);
        }
        break;

    default:
        printf("Unexpected error while processing DTLS: %d\n", err);
        // show why OpenSSL failed, not just the SSL_ERROR_* class
        ERR_print_errors_fp(stderr);
        assert(false);
        break;
    }

    return wantRead;
}

/**
 * Deliver a datagram produced by the other peer to ctx.
 *
 * Empties both of ctx's memory BIOs (anything still buffered is discarded),
 * then writes `data` into ctx->inBio for the next handshakeIteration().
 *
 * @param ctx   receiving peer.
 * @param data  datagram bytes.
 * @param size  number of bytes in data.
 */
void handleIncomingData(PeerContext* ctx, uint8_t* data, size_t size)
{
    // %zu: size is a size_t; passing it to %d is undefined behavior.
    printf("%s INCOMING DATA of size %zu\n", ctx->label, size);

    (void)BIO_reset(ctx->inBio);
    (void)BIO_reset(ctx->outBio);

    // BIO_write() takes an int length and reports how much it stored.
    assert(size <= INT_MAX);
    int written = BIO_write(ctx->inBio, data, (int)size);
    assert(written == (int)size);
    (void)written; // only inspected by assert()
}

/**
 * Log handshake progress reported through the OpenSSL info callback.
 *
 * @param ctx    peer whose label prefixes every printed line.
 * @param s      connection the event refers to.
 * @param where  SSL_CB_* / SSL_ST_* flags describing the event.
 * @param ret    for SSL_CB_ALERT the alert value; for SSL_CB_EXIT the
 *               result, where ret <= 0 is logged as a failure.
 */
void sslInfoCallbackInternal(PeerContext* ctx, const SSL* s, int where, int ret)
{
    // Fallback for events with neither the connect nor the accept bit set;
    // previously an uninitialized buffer was printed in that case (UB).
    const char* method = "undefined";
    int w = where & ~SSL_ST_MASK;

    if (w & SSL_ST_CONNECT)
        method = "SSL_connect";
    else if (w & SSL_ST_ACCEPT)
        method = "SSL_accept";

    if (where & SSL_CB_LOOP)
    {
        printf("%s %s: %s\n", ctx->label, method, SSL_state_string_long(s));
    }
    else if (where & SSL_CB_ALERT)
    {
        const char* direction = (where & SSL_CB_READ) ? "read" : "write";
        printf("%s SSL3 alert %s: %s : %s \n", ctx->label, direction,
            SSL_alert_type_string_long(ret),
            SSL_alert_desc_string_long(ret));
    }
    else if ((where & SSL_CB_EXIT) && ret <= 0)
    {
        // ret == 0 and ret < 0 are reported identically
        printf("%s %s failed in %s \n", ctx->label, method,
            SSL_state_string_long(s));
    }
}


int main()
{
    opensslInit();

    PeerContext clientCtx;
    clientCtx.role = DTLS_CLIENT;
    PeerContext serverCtx;
    serverCtx.role = DTLS_SERVER;

    initContext(&clientCtx);
    initContext(&serverCtx);

    uint8_t* data;
    size_t len;

    bool clientWantRead = false;
    bool serverWantRead = false;
    int timeoutMsClient = 0;
    int timeoutMsServer = 0;

    // starting to "listen" on server:
    handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
    assert(len == 0);

    // initial negotiation:
    while (1)
    {
        handshakeIteration(&clientCtx, &data, &len, &timeoutMsClient);
        if (len)
            handleIncomingData(&serverCtx, data, len);

        if (clientCtx.handshakeCompleted)
            break;

        handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
        if (len)
            handleIncomingData(&clientCtx, data, len);
    }

    assert(clientCtx.negotiationsCount == 1);
    printf("======== Renegotiating ========\n");
    renegotiate(&clientCtx);

    int clientPacketCounter = 0;

    // renegotiation loop:
    while (1)
    {
        clientWantRead = handshakeIteration(&clientCtx, &data, &len, &timeoutMsClient);
        if (len)
        {
            clientPacketCounter++;

            printf("Client has some data to send to the server. Size %d\n", len);

            // 
            if (clientPacketCounter != 2)
                handleIncomingData(&serverCtx, data, len);
            else
                printf("Intentionally dropping the packet and waiting\n");

            if (clientPacketCounter > 5)
            {
                printf("Error: too much packets sent!\n");
                goto end;
            }
        }

        // one renegotiation is enough (1 - initial negotiation, 2 - 1st renegotioation)
        if (clientCtx.negotiationsCount == 2)
            break;

        serverWantRead = handshakeIteration(&serverCtx, &data, &len, &timeoutMsServer);
        if (len)
        {
            printf("Server has some data to send to the client of size %d\n", len);

            handleIncomingData(&clientCtx, data, len);

            // client read request satisfied, no need to start timer at the end of the iteration:
            clientWantRead = false;
        }

        if (clientWantRead || serverWantRead)
        {
            int timeout = timeoutMsClient < timeoutMsServer ? timeoutMsClient : timeoutMsServer;
            printf("Waiting for %d ms for client to generate new flight\n", timeout);
            struct timespec rem;
            struct timespec req;
            
            req.tv_sec = timeout / 1000;
            req.tv_nsec = (timeout % 1000) * 1000000;
            nanosleep(&req, &rem);
            printf("Waiting is over\n");
        }
    }

    printf("Renegotiated successfully!\n");

end:

    SSL_shutdown(clientCtx.ssl);
    SSL_free(clientCtx.ssl);
    SSL_CTX_free(clientCtx.ctx);

    SSL_shutdown(serverCtx.ssl);
    SSL_free(serverCtx.ssl);
    SSL_CTX_free(serverCtx.ctx);

    opensslCleanup();

    return 0;
}

// Helper functions implementation



/* One mutex per OpenSSL static lock: allocated by opensslInit(), released by
 * opensslCleanup(). Objects with static storage start out as NULL. */
static pthread_mutex_t *mutex_buf;

/**
 * Thread-id callback installed with CRYPTO_set_id_callback() in opensslInit().
 *
 * @return the calling thread's pthread_t converted to unsigned long.
 *         NOTE(review): assumes pthread_t is an integer or pointer type,
 *         as on Linux/glibc.
 */
unsigned long idFunction(void)
{
    return (unsigned long)pthread_self();
}

/**
 * Static-locking callback installed with CRYPTO_set_locking_callback().
 *
 * @param mode  lock mutex n when CRYPTO_LOCK is set, unlock it otherwise.
 * @param n     index into mutex_buf, expected to be below CRYPTO_num_locks().
 * @param file  caller's source file (unused).
 * @param line  caller's source line (unused).
 */
void opensslLockingFunc(int mode, int n,
    const char* file, int line)
{
    (void)file;
    (void)line;

    pthread_mutex_t *lock = mutex_buf + n;
    if (!(mode & CRYPTO_LOCK))
    {
        pthread_mutex_unlock(lock);
        return;
    }
    pthread_mutex_lock(lock);
}


/**
 * Initialize OpenSSL and install pthread-based locking callbacks.
 *
 * Allocates and initializes one mutex per OpenSSL static lock in mutex_buf;
 * opensslCleanup() is the matching teardown. Aborts if the mutex array cannot
 * be allocated, since the locking callback would otherwise index NULL.
 */
void opensslInit(void)
{
    SSL_library_init();
    SSL_load_error_strings();
    OpenSSL_add_all_algorithms();

    int numLocks = CRYPTO_num_locks();
    // calloc also checks the count * size multiplication for overflow.
    mutex_buf = calloc((size_t)numLocks, sizeof *mutex_buf);
    if (mutex_buf == NULL && numLocks > 0)
    {
        fprintf(stderr, "opensslInit: cannot allocate %d mutexes\n", numLocks);
        abort();
    }
    for (int i = 0; i < numLocks; i++)
        pthread_mutex_init(&mutex_buf[i], NULL);

    CRYPTO_set_locking_callback(&opensslLockingFunc);
    CRYPTO_set_id_callback(&idFunction);
}

/**
 * Undo opensslInit(): remove the locking callbacks, destroy and free the
 * mutexes, and release OpenSSL's global error/cipher/ex_data tables.
 * The mutex teardown is skipped when mutex_buf is NULL (never allocated or
 * already freed).
 */
void opensslCleanup(void)
{
    CRYPTO_set_id_callback(NULL);
    CRYPTO_set_locking_callback(NULL);

    if (mutex_buf != NULL)
    {
        int numLocks = CRYPTO_num_locks();
        for (int i = 0; i < numLocks; i++)
            pthread_mutex_destroy(&mutex_buf[i]);

        free(mutex_buf);
        // no dangling pointer: a repeated call must not double-free
        mutex_buf = NULL;
    }

    ERR_free_strings();
    ERR_remove_state(0);
    EVP_cleanup();
    CRYPTO_cleanup_all_ex_data();
}


/* RSA modulus size in bits for generated keys.
 * NOTE(review): 1024 bits is below current recommendations; acceptable only
 * for a throwaway test identity. */
const int gKeyLength = 1024;

/* Number of random bits in a generated certificate serial number. */
const int gRandomBitsNum = 64;

/* Certificate validity period: 365 days, in seconds. */
const int gCertificateLifetime = 365 * 24 * 60 * 60;

/* notBefore offset in seconds: backdate by one day to tolerate slightly
 * incorrect system clocks. */
const int gCertificateValidationWindow = -(24 * 60 * 60);

/**
 * Generate an RSA key pair of gKeyLength bits with public exponent 65537.
 *
 * @return a new EVP_PKEY that owns the RSA key (free with EVP_PKEY_free()),
 *         or NULL if any allocation or generation step fails.
 */
EVP_PKEY* generateRsaKeyPair()
{
    EVP_PKEY* result = EVP_PKEY_new();
    BIGNUM* e = BN_new();
    RSA* rsaKey = RSA_new();

    bool ok = result != NULL && e != NULL && rsaKey != NULL
        && BN_set_word(e, 0x10001)
        && RSA_generate_key_ex(rsaKey, gKeyLength, e, NULL)
        && EVP_PKEY_assign_RSA(result, rsaKey);

    // the exponent is only needed during generation
    BN_free(e);

    if (ok)
        return result;   // rsaKey now belongs to result

    EVP_PKEY_free(result);
    RSA_free(rsaKey);
    return NULL;
}

/**
 * Build a self-signed X.509 v1 certificate for pkey.
 *
 * Subject and issuer are both CN=commonName. The serial number is
 * gRandomBitsNum pseudo-random bits. Validity runs from
 * gCertificateValidationWindow to gCertificateLifetime seconds relative to
 * now. The certificate is signed with pkey using SHA-256.
 *
 * @return the new certificate (free with X509_free()), or NULL on failure.
 */
X509* generateCertificate(EVP_PKEY* pkey, const char* commonName)
{
    X509* cert = X509_new();
    BIGNUM* serial = NULL;
    X509_NAME* subject = NULL;
    ASN1_INTEGER* certSerial = NULL;
    bool ok = false;

    if (cert == NULL || !X509_set_pubkey(cert, pkey))
        goto cleanup;

    // random serial number
    serial = BN_new();
    if (serial == NULL || !BN_pseudo_rand(serial, gRandomBitsNum, 0, 0))
        goto cleanup;
    certSerial = X509_get_serialNumber(cert);
    if (certSerial == NULL || !BN_to_ASN1_INTEGER(serial, certSerial))
        goto cleanup;

    // version field 0 == X.509 v1
    if (!X509_set_version(cert, 0L))
        goto cleanup;

    // subject == issuer == CN=commonName (self-signed)
    subject = X509_NAME_new();
    if (subject == NULL
        || !X509_NAME_add_entry_by_NID(subject, NID_commonName, MBSTRING_UTF8,
               (unsigned char*)commonName, -1, -1, 0)
        || !X509_set_subject_name(cert, subject)
        || !X509_set_issuer_name(cert, subject))
        goto cleanup;

    if (!X509_gmtime_adj(X509_get_notBefore(cert), gCertificateValidationWindow)
        || !X509_gmtime_adj(X509_get_notAfter(cert), gCertificateLifetime))
        goto cleanup;

    ok = X509_sign(cert, pkey, EVP_sha256()) != 0;

cleanup:
    BN_free(serial);
    X509_NAME_free(subject);
    if (!ok)
    {
        X509_free(cert);
        cert = NULL;
    }
    return cert;
}

/**
 * Generate a fresh RSA key pair and a self-signed certificate for it.
 *
 * @return the pair; the caller owns both members. On failure the failing
 *         member (and, if key generation failed, the certificate too) is NULL.
 */
DtlsIdentity generateIdentity(void)
{
    DtlsIdentity id;
    id.key = generateRsaKeyPair();
    // Never hand a NULL key to generateCertificate(): X509_set_pubkey()
    // dereferences it rather than reporting an error.
    id.certificate = (id.key != NULL)
        ? generateCertificate(id.key, "TestCompany Inc")
        : NULL;
    return id;
}


Reply via email to