Hello

While testing renegotiation for DTLS-SRTP, I found a crash on Windows.
The OpenSSL version is 1.0.1e; I also tested the latest 1.0.1 snapshot. I
observed two different stack traces:

  AddLiveService.dll!EVP_MD_size(const env_md_st * md) Line 273 C
> AddLiveService.dll!dtls1_do_write(ssl_st * s, int type) Line 275 C
  AddLiveService.dll!dtls1_retransmit_message(ssl_st * s, unsigned short
seq, unsigned long frag_off, int * found) Line 1293 C
  AddLiveService.dll!dtls1_retransmit_buffered_messages(ssl_st * s) Line
1145 C
  AddLiveService.dll!dtls1_handle_timeout(ssl_st * s) Line 450 C
  AddLiveService.dll!dtls1_read_bytes(ssl_st * s, int type, unsigned char *
buf, int len, int peek) Line 832 C
  AddLiveService.dll!dtls1_get_message_fragment(ssl_st * s, int st1, int
stn, long max, int * ok) Line 789 C
  AddLiveService.dll!dtls1_get_message(ssl_st * s, int st1, int stn, int
mt, long max, int * ok) Line 436 C
  AddLiveService.dll!ssl3_get_new_session_ticket(ssl_st * s) Line 2046 C
  AddLiveService.dll!dtls1_connect(ssl_st * s) Line 631 C
  AddLiveService.dll!SSL_do_handshake(ssl_st * s) Line 2562 C

and

  msvcr120d.dll!memcpy(unsigned char * dst, unsigned char * src, unsigned
long count) Line 188 Unknown
> dtls_test.exe!dtls1_get_message_fragment(ssl_st * s, int st1, int stn,
long max, int * ok) Line 789 C
  dtls_test.exe!dtls1_get_message(ssl_st * s, int st1, int stn, int mt,
long max, int * ok) Line 436 C
  dtls_test.exe!ssl3_get_new_session_ticket(ssl_st * s) Line 2046 C
  dtls_test.exe!dtls1_connect(ssl_st * s) Line 631 C
  dtls_test.exe!SSL_do_handshake(ssl_st * s) Line 2562 C

Both are segfaults (access violations). On Linux the rehandshake does not
finish at all (it fails on timeout after 1-2 minutes).

Below is a sample C++11 source file that reproduces this issue. An in-memory
BIO pair is used, with the client and server in the same process. When no
flights are dropped, everything works fine.

The sample can be compiled with MSVC 2013 on Windows, with g++ 4.7+ (g++ -o
dtlstest main.cpp -std=c++11 -lssl -lcrypto -lpthread -g), or with clang 3.2+.


---
Dmitry Sobinov
AddLive.com
Live video and voice for your application

#include <iostream>
#include <string>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <future>
#include <memory>
#include <vector>
#include <deque>
#include <chrono>
#include <algorithm>
#include <functional>
#include <stdint.h>
#include <assert.h>

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/x509.h>

// Can be built in MSVC 2013,
// gcc:
// g++ -o dtlstest dtlstest.cpp -std=c++11 -lssl -lcrypto -lpthread -g
// clang with libc++:
// clang++ -o dtlstest dtlstest.cpp -std=c++11 -lssl -lcrypto -lpthread -stdlib=libc++ -lc++abi -g


// Reference point for the millisecond timestamps printed by MLOG_D.
std::chrono::steady_clock::time_point logStartingTime = std::chrono::steady_clock::now();

// Logging helpers. MLOG_D / MLOG_E expect a `_label` in scope (the
// DtlsSrtpTransport member). Each macro expands to exactly one statement
// (do { } while (0)) so `if (c) MLOG_D(x); else ...` compiles as expected.
#define MLOG_D(x) do { std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - logStartingTime).count() << _label << x << std::endl; } while (0)
#define LOG_E(x) do { std::cout << "[ERROR] " << x << std::endl; } while (0)
#define MLOG_E(x) LOG_E(_label << x)

#ifdef X509_NAME
#undef X509_NAME // disable macro from wincrypt.h (included from dtls1.h/winsock.h)
#endif


/**
 * Key pair and matching certificate for one DTLS endpoint.
 * Both pointers are raw OpenSSL handles; they start out null so a
 * default-constructed identity never holds indeterminate pointers.
 */
struct DtlsIdentity
{
    EVP_PKEY* key = nullptr;
    X509* certificate = nullptr;
};

namespace
{
    // File-local helpers; their definitions appear further below.

    // --- OpenSSL 1.0.x thread support callbacks ---
    unsigned long idFunction();
    void opensslLockingFunc(int mode, int n, const char* /*file*/, int /*line*/);

    // --- library initialisation / teardown ---
    void opensslInit();
    void opensslCleanup();

    // --- key and certificate generation ---
    EVP_PKEY* generateRsaKeyPair();
    X509* generateCertificate(EVP_PKEY* pkey, const char* commonName);
    DtlsIdentity generateIdentity();

    // --- diagnostics: drain and print the OpenSSL error queue ---
    void logOpenSslErrors(const std::string& prefix);
}

typedef std::function<void()> DispatcherTask;

/**
 * Helper class to serialize all requests and data transmissions in
 * one separate thread (implementation of ActiveObject pattern).
 *
 * Tasks run in order of their scheduled fire time; tasks with equal fire
 * times run in submission order. Ids returned by push() start at 1, so 0
 * never identifies a queued task.
 *
 * stop() (or the destructor) must not be called from a task running on the
 * dispatcher thread itself, because it joins that thread.
 */
class AsyncDispatcher
{
    struct TimedTask
    {
        DispatcherTask task;
        std::chrono::steady_clock::time_point timeToFire;
        int id;
    };

public:

    AsyncDispatcher()
    {
        _thread = std::thread([this](){ run(); });
    }

    // Destroying a joinable std::thread calls std::terminate, so make sure
    // the worker is stopped even if stop() was never called explicitly.
    ~AsyncDispatcher()
    {
        stop();
    }

    AsyncDispatcher(const AsyncDispatcher&) = delete;
    AsyncDispatcher& operator=(const AsyncDispatcher&) = delete;

    /**
     * Schedule @p task to run on the dispatcher thread after @p delay.
     * @return id usable with cancelTimedTask().
     */
    int push(const DispatcherTask& task,
        std::chrono::milliseconds delay = std::chrono::milliseconds(0))
    {
        std::unique_lock<std::mutex> lk(_queueMutex);
        int id = _idCounter++;
        TimedTask rec{ task, std::chrono::steady_clock::now() + delay, id };

        // The queue is always sorted by fire time. Inserting after every task
        // that fires at the same time or earlier keeps equal-time tasks FIFO,
        // exactly like appending and stable-sorting, but in O(n).
        auto pos = std::upper_bound(_queue.begin(), _queue.end(), rec,
            [](const TimedTask& t1, const TimedTask& t2) -> bool { return t1.timeToFire < t2.timeToFire; });
        _queue.insert(pos, std::move(rec));
        lk.unlock();
        _condVar.notify_one();
        return id;
    }

    /**
     * Drop all pending tasks and join the worker thread.
     * Safe to call more than once.
     */
    void stop()
    {
        {
            std::lock_guard<std::mutex> lk(_queueMutex);
            _active = false;
            _queue.clear();
        }
        _condVar.notify_one();

        if (_thread.joinable())
            _thread.join();
    }

    /// Remove the pending task with @p id, if it has not run yet.
    void cancelTimedTask(int id)
    {
        std::unique_lock<std::mutex> lk(_queueMutex);
        _queue.erase(std::remove_if(_queue.begin(), _queue.end(),
            [=](const TimedTask& elem){ return elem.id == id; }),
            _queue.end());
        lk.unlock();
        _condVar.notify_one();
    }

private:

    // Block until the earliest task is due (return true with it popped) or
    // the dispatcher is stopped with an empty queue (return false).
    bool waitAndPop(TimedTask& poppedValue)
    {
        std::unique_lock<std::mutex> lk(_queueMutex);

        while (true)
        {
            if (_queue.empty())
            {
                // queue empty, new handlers are not allowed to add => exiting
                if (!_active)
                    return false;

                // wait for pushed element if no elements to process right away
                _condVar.wait(lk);
                continue;
            }

            auto nearestExpireTime = _queue.front().timeToFire;
            if (nearestExpireTime >= std::chrono::steady_clock::now())
            {
                _condVar.wait_until(lk, nearestExpireTime);
                continue;
            }

            // have some expired callbacks
            break;
        }

        poppedValue = std::move(_queue.front());
        _queue.pop_front();

        return true;
    }

    // Worker loop: run due tasks until stop() empties the queue.
    void run()
    {
        for (;;)
        {
            // fresh record each iteration so a finished task's captures are
            // released before waiting for the next one
            TimedTask rec;
            if (!waitAndPop(rec))
                return;
            rec.task();
        }
    }

    std::thread _thread;
    std::mutex _queueMutex;
    std::condition_variable _condVar;
    std::deque<TimedTask> _queue;

    bool _active = true;
    int _idCounter = 1;  // 0 is reserved as "no task"
};

/// Which side of the DTLS handshake an endpoint plays.
enum DtlsRole
{
    DTLS_CLIENT = 0,  // put into connect state (SSL_set_connect_state)
    DTLS_SERVER = 1   // put into accept state (SSL_set_accept_state)
};

/// Hands one outgoing DTLS datagram to the transport towards the peer.
using DtlsSendFunc = std::function<void(const std::vector<uint8_t>&)>;
/// Receives the outcome of a (re)negotiation: true on success, false on failure.
using DtlsConnectResultHandler = std::function<void(bool)>;

/**
 * One endpoint (client or server) of a DTLS-SRTP association driven through
 * in-memory BIOs.
 *
 * Public methods only enqueue work on the AsyncDispatcher; the OpenSSL calls
 * are made from the queued tasks. Outgoing handshake records are handed to
 * the send function, incoming records are fed in via handleIncomingData(),
 * and DTLS retransmissions are driven by at most one pending dispatcher timer
 * armed from DTLSv1_get_timeout().
 *
 * The key/certificate passed to setIdentity() are not freed by this class.
 */
class DtlsSrtpTransport
{
public:

    DtlsSrtpTransport(DtlsRole role, const std::string& label,
        AsyncDispatcher& dispatcher) :
        _role(role),
        _label(label),
        _dispatcher(dispatcher)
    {
    }

    ~DtlsSrtpTransport()
    {
        stopInternal();
    }

    // Holds raw OpenSSL handles that a copy would free twice.
    DtlsSrtpTransport(const DtlsSrtpTransport&) = delete;
    DtlsSrtpTransport& operator=(const DtlsSrtpTransport&) = delete;

    /**
     * Queue a datagram received from the peer. On the dispatcher thread it is
     * written into the input BIO and the handshake is advanced.
     */
    void handleIncomingData(const std::vector<uint8_t>& data)
    {
        // handle data in dispatcher thread
        _dispatcher.push([this, data]()
        {
            MLOG_D("INCOMING DATA of size " << data.size());

            // Data may arrive before start() created the SSL objects, and
            // &data[0] is undefined for an empty vector: ignore both cases.
            if (!_ssl || data.empty())
                return;

            (void)BIO_reset(_inBio);
            (void)BIO_reset(_outBio);
            ::BIO_write(_inBio, data.data(), static_cast<int>(data.size()));
            handshakeIteration();
        });
    }

    /// Callback invoked with true/false when a (re)negotiation succeeds/fails.
    void setResultHandler(const DtlsConnectResultHandler& h)
    {
        _resultHandler = h;
    }

    /// Callback used to transmit outgoing DTLS records to the peer.
    void setSendFunction(const DtlsSendFunc& sendFunc)
    {
        _sendFunc = sendFunc;
    }

    /// Key and certificate installed into the SSL_CTX when start() runs.
    /// Ownership is not taken.
    void setIdentity(const DtlsIdentity& identity)
    {
        _pkey = identity.key;
        _certificate = identity.certificate;
    }

    /// Create the SSL objects and begin the handshake (on the dispatcher thread).
    void start()
    {
        // perform on dispatcher thread
        _dispatcher.push([this](){ startInternal(); });
    }

    /// Request a full renegotiation. Requires a completed handshake and no
    /// renegotiation already in progress.
    void renegotiate()
    {
        // perform on dispatcher thread
        _dispatcher.push([this]()
        {
            MLOG_D("<<<<Renegotiation requested>>>>");
            assert(_handshakeCompleted);
            assert(!_activeRenegotiation);

            _activeRenegotiation = true;
            (void)BIO_reset(_inBio);
            (void)BIO_reset(_outBio);
            //SSL_renegotiate_abbreviated(_ssl);
            SSL_renegotiate(_ssl);
            handshakeIteration();
        });
    }

private:

    // Sentinel for "no retransmission timer queued".
    static const int kNoTimer = -1;

    void startInternal()
    {
        MLOG_D("Starting DTLS-SRTP");
        _sslCtx = createSslContext();
        assert(_sslCtx);
        _ssl = ::SSL_new(_sslCtx);
        assert(_ssl);

        _inBio = BIO_new(BIO_s_mem());
        _outBio = BIO_new(BIO_s_mem());

        SSL_set_app_data(_ssl, this);

        if (_role == DTLS_CLIENT)
            ::SSL_set_connect_state(_ssl);
        else
            ::SSL_set_accept_state(_ssl);

        ::SSL_set_bio(_ssl, _inBio, _outBio);  //< the SSL object owns the bio now

        MLOG_D("DTLS context initialization finished");

        handshakeIteration();
    }

    SSL_CTX* createSslContext()
    {
        SSL_CTX *ctx = (_role == DTLS_CLIENT) ?
            ::SSL_CTX_new(DTLSv1_client_method()) :
            ::SSL_CTX_new(DTLSv1_server_method());

        assert(ctx);
        assert(_certificate);
        assert(_pkey);

        // Report (and drain) failures here: errors left on the OpenSSL error
        // queue would later make SSL_get_error() report SSL_ERROR_SSL.
        if (::SSL_CTX_use_certificate(ctx, _certificate) != 1 ||
            ::SSL_CTX_use_PrivateKey(ctx, _pkey) != 1 ||
            ::SSL_CTX_check_private_key(ctx) != 1)
        {
            MLOG_E("Failed to install DTLS identity");
            logOpenSslErrors("SSL_CTX identity");
        }

        if (_role == DTLS_SERVER)
        {
            SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
        }

        ::SSL_CTX_set_info_callback(ctx, &DtlsSrtpTransport::sslInfoCallback);

        ::SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
            &DtlsSrtpTransport::sslVerifyCallback);
        ::SSL_CTX_set_verify_depth(ctx, 1);
        if (::SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH") != 1)
        {
            MLOG_E("Failed to set DTLS cipher list");
            logOpenSslErrors("SSL_CTX cipher list");
        }

        //::SSL_CTX_set_tlsext_use_srtp(ctx, "SRTP_AES128_CM_SHA1_32:SRTP_AES128_CM_SHA1_80");
        SSL_CTX_set_read_ahead(ctx, 1);

        // "bad decompression" error fix for some linuxes:
        SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION);

        return ctx;
    }

    void stopInternal()
    {
        if (_ssl)
        {
            ::SSL_shutdown(_ssl);
            ::SSL_free(_ssl);  // also frees the BIOs handed over by SSL_set_bio
            _ssl = nullptr;
            _inBio = nullptr;
            _outBio = nullptr;
        }
        if (_sslCtx)
        {
            ::SSL_CTX_free(_sslCtx);
            _sslCtx = nullptr;
        }
    }

    /**
     * Advance the handshake (or read once it is complete), track the
     * renegotiation state, keep exactly one retransmission timer armed while
     * OpenSSL reports one, and send whatever was written to the output BIO.
     */
    void handshakeIteration()
    {
        if (!_ssl)
            return;  // not started yet, or already stopped

        // we don't actually read data, but need this for SSL_read:
        uint8_t buf[4096];

        // SSL_read after initial negotiation, SSL_do_handshake on client side
        // when renegotiation requested
        int res = (_handshakeCompleted && !_activeRenegotiation) ?
            ::SSL_read(_ssl, buf, sizeof(buf)) : ::SSL_do_handshake(_ssl);

        // get pointer to data written by handshake
        uint8_t* outBioData = nullptr;
        long outBioLen = BIO_get_mem_data(_outBio, &outBioData);

        int err = ::SSL_get_error(_ssl, res);
        struct timeval timeout;

        // check if remote side requested renegotiation
        if (!_activeRenegotiation && _handshakeCompleted && SSL_renegotiate_pending(_ssl) == 1)
        {
            MLOG_D("Remote renegotiation detected");
            _activeRenegotiation = true;
        }

        // check if renegotiation finished
        bool renegotiationFinished = _activeRenegotiation && SSL_renegotiate_pending(_ssl) == 0;

        // handle handshake errors
        switch (err)
        {
        case SSL_ERROR_NONE:
            if (!_handshakeCompleted || renegotiationFinished)
            {
                _handshakeCompleted = true;
                _activeRenegotiation = false;
                reportSuccess();
                cancelRetransmitTimer();
            }
            break;

        case SSL_ERROR_WANT_READ:
            if (renegotiationFinished)
            {
                _activeRenegotiation = false;
                cancelRetransmitTimer();
                reportSuccess();
            }
            else if (DTLSv1_get_timeout(_ssl, &timeout))
            {
                int delay = static_cast<int>(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
                MLOG_D("WANT_READ: setting new timer for " << delay << "ms");
                armRetransmitTimer(delay);
            }
            else
            {
                // OpenSSL has no retransmission timer running
                cancelRetransmitTimer();
            }
            break;

        default:
            MLOG_E("Unexpected error while processing DTLS: " << err);
            logOpenSslErrors("SSL reading");
            reportFailure();
            cancelRetransmitTimer();
            assert(false);
            // don't write any data, just return:
            return;
        }

        if (outBioLen > 0 && _sendFunc)
        {
            MLOG_D("Sending handshake data for DTLS; size " << outBioLen);
            assert(outBioLen < 1476); // do not exceed ethernet MTU
            std::vector<uint8_t> data(outBioData, outBioData + outBioLen);
            _sendFunc(data);
        }
    }

    // Replace any pending retransmission timer with one firing after delayMs.
    // Without the cancel, every WANT_READ would leave another stale timer queued.
    void armRetransmitTimer(int delayMs)
    {
        cancelRetransmitTimer();
        _currentTimerId = _dispatcher.push([this](){ receiveTimerExpired(); },
            std::chrono::milliseconds(delayMs));
    }

    // Drop the pending retransmission timer, if any.
    void cancelRetransmitTimer()
    {
        if (_currentTimerId != kNoTimer)
        {
            _dispatcher.cancelTimedTask(_currentTimerId);
            _currentTimerId = kNoTimer;
        }
    }

    void receiveTimerExpired()
    {
        // this timer has fired and is no longer in the dispatcher queue
        _currentTimerId = kNoTimer;

        MLOG_D("DTLS timer expired. Asking OpenSSL to repeat operations");

        (void)BIO_reset(_inBio);
        (void)BIO_reset(_outBio);

        handshakeIteration();
    }

    void reportFailure()
    {
        MLOG_E("Reporting failure");
        if (_resultHandler)
            _resultHandler(false);
    }

    void reportSuccess()
    {
        MLOG_D("Reporting negotiation success");
        if (_resultHandler)
            _resultHandler(true);
    }

    static int sslVerifyCallback(int /*ok*/, X509_STORE_CTX* /*store*/)
    {
        // we don't verify certificate here for simplicity
        return 1;
    }

    static void sslInfoCallback(const SSL* s, int where, int ret)
    {
        auto this_ = reinterpret_cast<DtlsSrtpTransport*>(SSL_get_app_data(s));
        if (this_)
            this_->sslInfoCallbackInternal(s, where, ret);
    }

    void sslInfoCallbackInternal(const SSL* s, int where, int ret)
    {
        std::string method = "undefined";
        int w = where & ~SSL_ST_MASK;
        if (w & SSL_ST_CONNECT)
        {
            method = "SSL_connect";
        }
        else if (w & SSL_ST_ACCEPT)
        {
            method = "SSL_accept";
        }


        if (where & SSL_CB_LOOP)
        {
            MLOG_D(method << ": " << SSL_state_string_long(s));
        }
        else if (where & SSL_CB_ALERT)
        {
            const char* direction = (where & SSL_CB_READ) ? "read" : "write";
            MLOG_D("SSL3 alert " << direction
                << ":" << SSL_alert_type_string_long(ret)
                << ":" << SSL_alert_desc_string_long(ret));
        }
        else if (where & SSL_CB_EXIT)
        {
            if (ret == 0)
            {
                MLOG_D(method << " failed in " << SSL_state_string_long(s));
            }
            else if (ret < 0)
            {
                MLOG_D(method << " error in " << SSL_state_string_long(s));
            }
        }
    }


    SSL* _ssl = nullptr;
    SSL_CTX* _sslCtx = nullptr;
    BIO* _inBio = nullptr;
    BIO* _outBio = nullptr;
    X509* _certificate = nullptr;
    EVP_PKEY* _pkey = nullptr;

    bool _handshakeCompleted = false;
    DtlsConnectResultHandler _resultHandler;
    DtlsSendFunc _sendFunc;


    DtlsRole _role;
    std::string _label;

    bool _activeRenegotiation = false;
    int _currentTimerId = kNoTimer;  // id of the queued retransmission timer

    AsyncDispatcher& _dispatcher;
};



int main()
{
    opensslInit();

    AsyncDispatcher dispatcher;
    
    DtlsSrtpTransport client(DTLS_CLIENT, " [C] ", dispatcher);
    DtlsSrtpTransport server(DTLS_SERVER, " [S] ", dispatcher);

    std::promise<bool> clientResultPromise;
    std::promise<bool> serverResultPromise;

    // on negotiation set promises so we can go on
    client.setResultHandler([&](bool result){ clientResultPromise.set_value(result); });
    server.setResultHandler([&](bool result){ serverResultPromise.set_value(result); });

    client.setIdentity(generateIdentity());
    server.setIdentity(generateIdentity());

    int clientCounter = 0;
    client.setSendFunction([&](const std::vector<uint8_t>& data)
    {
        // drop every 2nd packet
        if ((++clientCounter % 2) == 0)
            server.handleIncomingData(data);
    });

    int serverCounter = 0;
    server.setSendFunction([&](const std::vector<uint8_t>& data)
    {
        // drop every 2nd packet
        if ((++serverCounter % 2) == 0)
            client.handleIncomingData(data);
    });

    client.start();
    server.start();

    // block until get results in promises
    auto clientResult = clientResultPromise.get_future().get();
    auto serverResult = serverResultPromise.get_future().get();

    assert(clientResult);
    assert(serverResult);

    /// renegotiation

    // reset promises
    clientResultPromise = std::promise<bool>();
    serverResultPromise = std::promise<bool>();

    // ask for renegotiation
    client.renegotiate();

    // block until get results
    clientResult = clientResultPromise.get_future().get();
    serverResult = serverResultPromise.get_future().get();

    assert(clientResult);
    assert(serverResult);

    dispatcher.stop();
    opensslCleanup();
}

// Helper functions implementation

namespace
{
    // One mutex per OpenSSL static lock, used by opensslLockingFunc.
    // Filled by opensslInit() and released by opensslCleanup().
    std::vector<std::shared_ptr<std::mutex>> opensslMutexes;


    /// OpenSSL 1.0.x id callback: numeric id of the calling thread.
    unsigned long idFunction()
    {
        return static_cast<unsigned long>(
            std::hash<std::thread::id>()(std::this_thread::get_id()));
    }

    /// OpenSSL 1.0.x locking callback: lock/unlock static lock number @p n.
    void opensslLockingFunc(int mode, int n,
        const char* /*file*/, int /*line*/)
    {
        if (mode & CRYPTO_LOCK)
            opensslMutexes[n]->lock();
        else
            opensslMutexes[n]->unlock();
    }


    /// Initialise libssl/libcrypto and install the thread callbacks.
    void opensslInit()
    {
        ::SSL_library_init();
        ::SSL_load_error_strings();
        ::OpenSSL_add_all_algorithms();

        opensslMutexes.resize(::CRYPTO_num_locks());
        for (auto& mutex : opensslMutexes)
            mutex.reset(new std::mutex());
        ::CRYPTO_set_locking_callback(&opensslLockingFunc);
        ::CRYPTO_set_id_callback(&idFunction);
    }

    /// Undo opensslInit(): remove the callbacks and release library state.
    void opensslCleanup()
    {
        ::CRYPTO_set_id_callback(0);
        ::CRYPTO_set_locking_callback(0);
        ::ERR_free_strings();
        ::ERR_remove_state(0);
        ::EVP_cleanup();
        ::CRYPTO_cleanup_all_ex_data();

        // the locking callback is no longer installed, so the mutexes are unused
        opensslMutexes.clear();
    }


    const int gKeyLength = 1024;

    // number of random bits for certificate serial number
    const int gRandomBitsNum = 64;

    // one year certificate validity
    const int gCertificateLifetime = 60 * 60 * 24 * 365;

    // to compensate for slightly incorrect system clocks
    const int gCertificateValidationWindow = -60 * 60 * 24;

    /// Generate an RSA key pair (public exponent 65537); NULL on failure.
    EVP_PKEY* generateRsaKeyPair()
    {
        EVP_PKEY* pkey = EVP_PKEY_new();
        BIGNUM* exponent = BN_new();
        RSA* rsa = RSA_new();
        if (!pkey || !exponent || !rsa ||
            !BN_set_word(exponent, 0x10001) ||
            !RSA_generate_key_ex(rsa, gKeyLength, exponent, NULL) ||
            !EVP_PKEY_assign_RSA(pkey, rsa))
        {
            // assign is the last step, so on this path rsa is not yet owned by pkey
            EVP_PKEY_free(pkey);
            BN_free(exponent);
            RSA_free(rsa);
            return NULL;
        }

        BN_free(exponent);
        return pkey;
    }

    /**
     * Build a self-signed certificate for @p pkey with subject/issuer
     * CN=@p commonName, signed with SHA-256. Returns NULL on any failure.
     */
    X509* generateCertificate(EVP_PKEY* pkey, const char* commonName)
    {
        if (!pkey || !commonName)
            return NULL;

        // All locals are declared before the first goto so no jump skips
        // an initialisation.
        X509* x509 = NULL;
        BIGNUM* serialNumber = NULL;
        ASN1_INTEGER* asn1SerialNumber = NULL;
        X509_NAME* name = NULL;
        // X509_NAME_add_entry_by_NID takes a non-const unsigned char* in 1.0.1
        unsigned char* commonNameBytes =
            reinterpret_cast<unsigned char*>(const_cast<char*>(commonName));

        if ((x509 = X509_new()) == NULL)
            goto error;

        if (!X509_set_pubkey(x509, pkey))
            goto error;

        if ((serialNumber = BN_new()) == NULL ||
            !BN_pseudo_rand(serialNumber, gRandomBitsNum, 0, 0) ||
            (asn1SerialNumber = X509_get_serialNumber(x509)) == NULL ||
            !BN_to_ASN1_INTEGER(serialNumber, asn1SerialNumber))
            goto error;

        if (!X509_set_version(x509, 0L))
            goto error;

        if ((name = X509_NAME_new()) == NULL ||
            !X509_NAME_add_entry_by_NID(name, NID_commonName, MBSTRING_UTF8,
            commonNameBytes, -1, -1, 0) ||
            !X509_set_subject_name(x509, name) ||
            !X509_set_issuer_name(x509, name))
            goto error;

        if (!X509_gmtime_adj(X509_get_notBefore(x509), gCertificateValidationWindow) ||
            !X509_gmtime_adj(X509_get_notAfter(x509), gCertificateLifetime))
            goto error;

        if (!X509_sign(x509, pkey, EVP_sha256()))
            goto error;

        BN_free(serialNumber);
        X509_NAME_free(name);
        return x509;

    error:
        BN_free(serialNumber);
        X509_NAME_free(name);
        X509_free(x509);
        return NULL;
    }

    /// Fresh RSA key plus matching self-signed certificate.
    DtlsIdentity generateIdentity()
    {
        DtlsIdentity id;
        id.key = generateRsaKeyPair();
        id.certificate = generateCertificate(id.key, "TestCompany Inc");
        return id;
    }

    /// Print and drain every entry of the calling thread's OpenSSL error queue.
    void logOpenSslErrors(const std::string& prefix)
    {
        char errorBuf[200];
        unsigned long err;

        while ((err = ERR_get_error()) != 0)
        {
            ERR_error_string_n(err, errorBuf, sizeof(errorBuf));
            LOG_E(prefix << ": " << errorBuf);
        }
    }
}

Reply via email to