/*******************************************************************************
 * Copyright 1999 PLATINUM technology IP, inc.  All rights reserved.
 *
 * This software is provided 'as is'.  See README for details.
 *
 * Description:	OpenSSL backend.
 *
 * History:
 *   Apr-99 RR  Created.
 ******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "ssl.h"
#include "pem.h"

#include "ssld_.h"
#include "rsinfo.h"

/* Per-file hook from rsinfo.h.  NOTE(review): its expansion is not visible
 * here -- presumably it records this file's name for the RS* logging
 * macros used below; confirm against rsinfo.h. */
RSINFO_FILENAME

/* One entry of the ssls[] connection table. */
struct ctx_t
{
    SSL *ssl;		/* OpenSSL connection; NULL marks the slot as free */
    int ctx;		/* caller value echoed back to ssld_ssl_response() */
    int ref;		/* caller value echoed back to ssld_ssl_response() */
    int type;		/* SSLD_CLIENT or SSLD_SERVER */
    int state;		/* 0 until post-handshake processing (peer subject
			 * report) has run for this connection, then 1 */
};

/*
 * Connection table.  The SSLD_SSL handle returned by ssld_ssl_ctx() is an
 * index into this array; a slot is free while its .ssl member is NULL.
 */
static struct ctx_t ssls[10];

/* Shared contexts built by ssld_ssl_init(): ssl_ctx_client is used for
 * SSLD_CLIENT connections, ssl_ctx_server for SSLD_SERVER connections. */
static SSL_CTX *ssl_ctx_client;
static SSL_CTX *ssl_ctx_server;
/* Deepest chain position verify_callback() accepts before failing with
 * X509_V_ERR_CERT_CHAIN_TOO_LONG. */
static int verify_depth=5;

/* Memory BIO that collects OpenSSL error text for local_get_errors(). */
static BIO *bio_err;

static RSA *tmp_rsa_cb(SSL *s, int export, int keylength);
/* NOTE(review): these three callbacks are declared without parameter
 * lists; their definitions later in the file carry full prototypes. */
static void apps_ssl_info_callback();
static int verify_callback();
static long bio_dump_cb();
static void local_get_errors(void);
/* ssld_ssl_init() calls this before its definition.  Without a prior
 * static declaration that call implicitly declares an extern function,
 * which conflicts with the static definition further down. */
static int set_cert_stuff(SSL_CTX *ctx, char *cert_file, char *key_file);

/* handle SSL requests */

int
ssld_ssl_init(char *cert_file, char *ca_file, int client_auth, int use_tls)
{
    RSINFO("ssld_ssl_init()");
    /* bio_err = BIO_new_fp(stdout, BIO_NOCLOSE); */
    bio_err = BIO_new(BIO_s_mem());
    SSL_load_error_strings();
    SSLeay_add_all_algorithms();
    /* SSLeay_add_ssl_algorithms(); */

    if (!(ssl_ctx_server = SSL_CTX_new(use_tls ?
	TLSv1_server_method() : SSLv3_server_method())))
    {
	local_get_errors();
	return -1;
    }
    SSL_CTX_set_options(ssl_ctx_server, SSL_OP_NO_SSLv2);
    SSL_CTX_set_quiet_shutdown(ssl_ctx_server, 1);
    if (rs_debug > RS_DEBUG_WARN)
	SSL_CTX_set_info_callback(ssl_ctx_server, apps_ssl_info_callback);
    SSL_CTX_sess_set_cache_size(ssl_ctx_server, 128);
    if ((!SSL_CTX_load_verify_locations(ssl_ctx_server, ca_file, 0)) ||
	(!SSL_CTX_set_default_verify_paths(ssl_ctx_server)))
    {
	RSFATAL();
	local_get_errors();
    }
    if (!set_cert_stuff(ssl_ctx_server, cert_file, 0))
	return -1;
    SSL_CTX_set_tmp_rsa_callback(ssl_ctx_server, tmp_rsa_cb);
#ifdef UNUSED
    if (cipher != NULL)
	SSL_CTX_set_cipher_list(ssl_ctx_server, cipher);
#endif /* UNUSED */
    SSL_CTX_set_verify(ssl_ctx_server, client_auth ? SSL_VERIFY_PEER |
	    SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE :
	    SSL_VERIFY_PEER,
	verify_callback);

#ifdef UNUSED
    SSL_CTX_set_client_CA_list(ssl_ctx_server, SSL_load_client_CA_file(cert_file));
#endif /* UNUSED */

    if (!(ssl_ctx_client = SSL_CTX_new(use_tls ?
	TLSv1_client_method() : SSLv3_client_method())))
    {
	local_get_errors();
	return -1;
    }
    SSL_CTX_set_options(ssl_ctx_client, SSL_OP_NO_SSLv2);
    SSL_CTX_set_quiet_shutdown(ssl_ctx_client, 1);
    if (rs_debug > RS_DEBUG_WARN)
	SSL_CTX_set_info_callback(ssl_ctx_client, apps_ssl_info_callback);
    if ((!SSL_CTX_load_verify_locations(ssl_ctx_client, ca_file, 0)) ||
	(!SSL_CTX_set_default_verify_paths(ssl_ctx_client)))
    {
	RSFATAL();
	local_get_errors();
    }
    if (!set_cert_stuff(ssl_ctx_client, cert_file, 0))
	return -1;
    SSL_CTX_set_tmp_rsa_callback(ssl_ctx_client, tmp_rsa_cb);
#ifdef UNUSED
    if (cipher != NULL)
	SSL_CTX_set_cipher_list(ssl_ctx_client, cipher);
#endif /* UNUSED */
    SSL_CTX_set_verify(ssl_ctx_client, SSL_VERIFY_PEER, verify_callback);
    /* SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT|SSL_VERIFY_CLIENT_ONCE */
    if (SSL_CTX_use_certificate_file(ssl_ctx_client, cert_file,
	SSL_FILETYPE_PEM) <= 0)
    {
	local_get_errors();
	return -1;
    }
    if (SSL_CTX_use_PrivateKey_file(ssl_ctx_client, cert_file,
	SSL_FILETYPE_PEM) <= 0)
    {
	local_get_errors();
	return -1;
    }
    if (!SSL_CTX_check_private_key(ssl_ctx_client))
    {
	RSWARN("Private key does not match public key");
	return -1;
    }

#ifdef UNUSED
    SSL_CTX_set_client_CA_list(ssl_ctx_client, SSL_load_client_CA_file("client.pem"));
#endif /* UNUSED */

    return 0;
}

/*
 * Create a new connection and store it in a free slot of ssls[].
 *
 * ctx, ref  caller values echoed back on every ssld_ssl_response() call
 *           made for this connection.
 * type      SSLD_CLIENT (ssl_ctx_client, connect state) or
 *           SSLD_SERVER (ssl_ctx_server, accept state).
 *
 * The connection is wired to two memory BIOs: the read BIO is fed with
 * peer bytes by ssld_ssl_request(), and the write BIO collects bytes that
 * ssld_ssl_request() forwards to the peer.
 *
 * Returns the slot index (the SSLD_SSL handle), or -1 if `type` is invalid,
 * the table is full, or an allocation fails.  Nothing is leaked on any
 * failure path.
 */
SSLD_SSL
ssld_ssl_ctx(int ctx, int ref, int type)
{
    int nslots = (int)(sizeof ssls / sizeof ssls[0]);
    SSL *con;
    BIO *rbio, *wbio;
    int i;

    RSINFO3("ssld_ssl_ctx(%d, %d, %d)", ctx, ref, type);
    if (type != SSLD_CLIENT && type != SSLD_SERVER)
    {
	RSFATAL();
	return -1;
    }

    /* Reserve a slot before allocating anything, so a full table needs no
     * unwinding. */
    for (i = 0; i < nslots; ++i)
	if (!ssls[i].ssl)
	    break;
    if (i == nslots)
    {
	RSFATAL();
	return -1;
    }

    if (!(rbio = BIO_new(BIO_s_mem())))
    {
	RSWARN("Cannot allocate context");
	return -1;
    }
    if (!(wbio = BIO_new(BIO_s_mem())))
    {
	BIO_free(rbio);
	RSWARN("Cannot allocate context");
	return -1;
    }
    if (!(con = SSL_new(type == SSLD_CLIENT ? ssl_ctx_client : ssl_ctx_server)))
    {
	BIO_free(rbio);
	BIO_free(wbio);
	RSWARN("Cannot allocate context");
	local_get_errors();
	return -1;
    }
    SSL_clear(con);
    /* From here on con owns both BIOs; SSL_free(con) releases them. */
    SSL_set_bio(con, rbio, wbio);
    if (type == SSLD_CLIENT)
	SSL_set_connect_state(con);
    else
	SSL_set_accept_state(con);
    if (rs_debug > RS_DEBUG_WARN)
    {
	con->debug = 1;
	BIO_set_callback(SSL_get_rbio(con), bio_dump_cb);
	/* BIO_set_callback_arg(SSL_get_rbio(con), bio_err); */
    }

    ssls[i].ssl = con;
    ssls[i].ctx = ctx;
    ssls[i].ref = ref;
    ssls[i].type = type;
    ssls[i].state = 0;
    return i;
}

/*
 * Release the connection in slot `ssl` and mark the slot free for reuse by
 * ssld_ssl_ctx().
 *
 * SSL_free() also releases the two memory BIOs attached with SSL_set_bio()
 * in ssld_ssl_ctx().  Out-of-range handles are rejected, including the -1
 * that ssld_ssl_ctx() returns on failure.  Freeing an already-free slot
 * does nothing.
 *
 * NOTE(review): do not call this for a handle from inside a
 * ssld_ssl_response() callback made by ssld_ssl_request() for the same
 * handle -- ssld_ssl_request() keeps using the SSL object after those
 * callbacks return.
 */
void
ssld_ssl_free(SSLD_SSL ssl)
{
    RSINFO1("ssld_ssl_free(%d)", ssl);
    if (ssl < 0 || ssl >= (int)(sizeof ssls / sizeof ssls[0]))
    {
	RSFATAL();
	return;
    }
    if (ssls[ssl].ssl)
    {
	SSL_free(ssls[ssl].ssl);
	ssls[ssl].ssl = 0;
    }
}

/*
 * Process one request for the connection in slot `ssl`.
 *
 * op == SSLD_OP_USER_DATA
 *     Encrypt `length` bytes at `data` with SSL_write().  Writing stops
 *     early, without error, when OpenSSL first needs data from the peer.
 * op == SSLD_OP_PEER_DATA
 *     Feed `length` bytes received from the peer into the read BIO.  For
 *     server connections the handshake is driven with SSL_accept().  Once
 *     the handshake has finished, the peer certificate's subject
 *     (subject->bytes) is reported once as SSLD_OP_SUBJECT, and all
 *     decrypted application data is delivered as SSLD_OP_USER_DATA.
 *
 * Bytes OpenSSL wants sent to the peer are drained from the write BIO and
 * delivered as ssld_ssl_response(ctx, ref, SSLD_OP_PEER_DATA, ...).
 *
 * Returns the number of input bytes consumed.  Returns -1 in these cases:
 *   - the handle is invalid;
 *   - op is unknown;
 *   - an SSL error occurs;
 *   - the peer has closed the SSL session.
 */
int
ssld_ssl_request(SSLD_SSL ssl, int op, int length, char *data)
{
    char buf[2000];
    struct ctx_t *ctx;
    SSL *con;
    X509 *peer;
    X509_NAME *subject;
    char *str;
    int error;
    int sent = 0;
    int len;
    int verify_error;

    RSINFO3("ssld_ssl_request(%d, %d, %d)", ssl, op, length);
    /* Reject out-of-range handles, including the -1 that ssld_ssl_ctx()
     * returns on failure. */
    if (ssl < 0 || ssl >= (int)(sizeof ssls / sizeof ssls[0]))
    {
	RSFATAL();
	return -1;
    }
    ctx = &ssls[ssl];
    if (!(con = ctx->ssl))
    {
	RSFATAL();
	return -1;
    }

    switch (op)
    {
    case SSLD_OP_USER_DATA:	/* user data */
	/* Encrypt as much as OpenSSL accepts; the resulting records are
	 * drained from the write BIO after the switch. */
	error = 0;
	while (error == 0)
	{
	    len = SSL_write(con, data, length);
RSINFO2("case 0: SSL_write(%d) --> %d", length, len);
	    if (len > 0)
	    {
		length -= len;
		data += len;
		sent += len;
	    }
	    switch (SSL_get_error(con, len))
	    {
	    case SSL_ERROR_NONE:
		break;
	    case SSL_ERROR_SYSCALL:
		/* Memory BIOs make no system calls, but OpenSSL reports
		 * SSL_ERROR_SYSCALL whenever SSL_write() returned <= 0 with
		 * an empty error queue, i.e. nothing was written.  That is
		 * harmless for an empty request.  With bytes still pending,
		 * looping would retry the same failing write forever. */
		if (length > 0)
		{
		    RSWARN("ERROR");
		    local_get_errors();
		    return -1;
		}
		break;
	    case SSL_ERROR_WANT_WRITE:
		RSINFO("Write block");
		RSFATAL();
		return -1;
	    case SSL_ERROR_WANT_READ:
		RSINFO("Read block");
		error = 1;	/* break out of loop */
		break;
	    case SSL_ERROR_WANT_X509_LOOKUP:
		RSINFO("X509 block");
		RSFATAL();
		return -1;
	    case SSL_ERROR_SSL:
		RSWARN("ERROR");
		local_get_errors();
		return -1;
	    case SSL_ERROR_ZERO_RETURN:
		RSINFO("Complete");
		return -1;
	    default:
		RSFATAL();
		return -1;
	    }
	    if (length == 0)
		break;
	}
	break;
    case SSLD_OP_PEER_DATA:	/* peer data */
	BIO_write(con->rbio, data, length);
RSINFO1("case 1: BIO_write(%d)", length);
	sent = length;
	/* Server side: advance the handshake explicitly and flush any
	 * handshake records before looking for application data. */
	if (con->state != SSL_ST_OK && ctx->type == SSLD_SERVER)
	{
	    if ((len = SSL_accept(con)) <= 0)
	    {
		if (!SSL_want_read(con) ||
		    !BIO_should_read(SSL_get_rbio(con)))
		{
		    RSEVENT("ERROR");
		    verify_error=SSL_get_verify_result(con);
		    if (verify_error != X509_V_OK)
			RSEVENT1("verify error:%s",
			    X509_verify_cert_error_string(verify_error));
		    else
			    local_get_errors();
		    return -1;	/* error */
		}
	    }

	    while ((len = BIO_read(con->wbio, buf, sizeof buf)) > 0)
		ssld_ssl_response(ctx->ctx, ctx->ref, SSLD_OP_PEER_DATA,
		    len, buf);
	    if (con->state != SSL_ST_OK)
		return sent;
	}

	/* First time the handshake is seen complete: report the peer. */
	if (!ctx->state && !SSL_in_init(con))
	{
	    ctx->state = 1;

#ifdef UNUSED
	    PEM_write_bio_SSL_SESSION(bio_err,SSL_get_session(con));
#endif /* UNUSED */

	    peer=SSL_get_peer_certificate(con);
	    if (peer != NULL)
	    {
		RSINFO("Got peer certificate");
		subject = X509_get_subject_name(peer);
		ssld_ssl_response(ctx->ctx, ctx->ref, SSLD_OP_SUBJECT,
		    subject->bytes->length, subject->bytes->data);

#ifdef UNUSED
		BIO_printf(bio_err,"Peer certificate\n");
		PEM_write_bio_X509(bio_err,peer);
		X509_NAME_oneline(subject,buf,BUFSIZ);
		RSINFO1("Subject = %s", buf);
		X509_NAME_oneline(X509_get_issuer_name(peer),buf,BUFSIZ);
		RSINFO1("Issuer = %s",buf);
#endif /* UNUSED */
		X509_free(peer);
	    }

#ifdef UNUSED
	    if (SSL_get_shared_ciphers(con,buf,BUFSIZ) != NULL)
		RSINFO1("Shared ciphers: %s", buf);
	    str = SSL_CIPHER_get_name(SSL_get_current_cipher(con));
	    RSINFO1("CIPHER is %s", (str != NULL) ? str : "(NONE)");
	    if (con->hit) RSINFO("Reused session-id");
#endif /* UNUSED */
	}

	/* Deliver every complete decrypted record now available. */
	len = 1;
	while (len > 0)
	{
	    len = SSL_read(con, buf, sizeof buf);
RSINFO1("case 1: SSL_read() --> %d", len);
	    switch (SSL_get_error(con, len))
	    {
	    case SSL_ERROR_NONE:
		ssld_ssl_response(ctx->ctx, ctx->ref, SSLD_OP_USER_DATA,
		    len, buf);
		break;
	    case SSL_ERROR_WANT_WRITE:
	    case SSL_ERROR_WANT_READ:
	    case SSL_ERROR_WANT_X509_LOOKUP:
		len = 0;
		break;
	    case SSL_ERROR_SYSCALL:
	    case SSL_ERROR_SSL:
		RSWARN("SSL Error");
		local_get_errors();
		/* close con and ref */
		return -1;	/* error */
	    case SSL_ERROR_ZERO_RETURN:
		RSEVENT("SSL Done");
		SSL_set_shutdown(con,SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
RSINFO("Closing");
		/* close con and ref */
		return -1;	/* done but not error */
	    default:
		RSFATAL();
		exit(99);
	    }
	}
	break;
    default:
	RSFATAL();
	return -1;
    }
    /* Forward whatever OpenSSL queued for the peer. */
    while ((len = BIO_read(con->wbio, buf, sizeof buf)) > 0)
    {
RSINFO1("case *: BIO_read() --> %d", len);
	ssld_ssl_response(ctx->ctx, ctx->ref, SSLD_OP_PEER_DATA, len, buf);
    }
    return sent;
}

/*
 * Temporary-RSA-key callback registered on both contexts by ssld_ssl_init().
 *
 * Generates a key of `keylength` bits (public exponent RSA_F4) on first use
 * and hands out that same key on every later call.  Returns NULL when built
 * with NO_RSA or if key generation failed; in the latter case generation is
 * retried on the next call.
 *
 * NOTE(review): the cached key is returned whatever `keylength` later calls
 * request; `s` and `export` are unused.
 */
static RSA *tmp_rsa_cb(SSL *s, int export, int keylength)
{
    static RSA *cached_key = NULL;

    if (cached_key != NULL)
	return cached_key;

    RSEVENT1("Generating temp (%d bit) RSA key", keylength);
#ifndef NO_RSA
    cached_key = RSA_generate_key(keylength, RSA_F4, NULL, NULL);
#endif
    return cached_key;
}

/*============================ s_cb.c ===============================*/

/*
 * Certificate-verification callback installed on both contexts by
 * ssld_ssl_init().  OpenSSL calls it for each certificate in the peer's
 * chain, passing its own verdict in `ok`.
 *
 * The callback logs the certificate subject and any verification error.
 * It also rejects certificates deeper than verify_depth by recording
 * X509_V_ERR_CERT_CHAIN_TOO_LONG in the store context, and emits an extra
 * event message for a few error codes.
 *
 * Returns the final verdict: non-zero to accept this certificate.
 */
static int verify_callback(int ok, X509_STORE_CTX *ctx)
{
    char name[256];
    X509 *cert = X509_STORE_CTX_get_current_cert(ctx);
    int err = X509_STORE_CTX_get_error(ctx);
    int depth = X509_STORE_CTX_get_error_depth(ctx);

    X509_NAME_oneline(X509_get_subject_name(cert), name, sizeof name);
    RSINFO2("verify_callback: depth=%d: %s", depth, name);

    if (!ok)
    {
	RSWARN2("Verify error %d: %s", err,
	    X509_verify_cert_error_string(err));
#ifdef UNUSED
	if (verify_depth >= depth)
	{
	    ok=1;
	    verify_error=X509_V_OK;
	}
	else
	{
	    /* ok=0; */
	    verify_error=X509_V_ERR_CERT_CHAIN_TOO_LONG;
	}
#endif /* UNUSED */
    }
    else if (depth > verify_depth)
    {
	/* OpenSSL accepted it, but the chain is deeper than we allow. */
	RSWARN("Certificate chain too long");
	ok = 0;
	ctx->error = X509_V_ERR_CERT_CHAIN_TOO_LONG;
    }

    /* Re-read the context error so the override above is what is tested. */
    switch (ctx->error)
    {
    case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
	X509_NAME_oneline(X509_get_issuer_name(cert), name, sizeof name);
	RSEVENT1("Cannot get issuer certificate: %s", name);
	break;
    case X509_V_ERR_CERT_NOT_YET_VALID:
    case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
	RSEVENT("Error in notBefore field");
#ifdef UNUSED
	BIO_printf(bio_err,"notBefore=");
	ASN1_TIME_print(bio_err,X509_get_notBefore(cert));
	BIO_printf(bio_err,"\n");
#endif /* UNUSED */
	break;
    case X509_V_ERR_CERT_HAS_EXPIRED:
	RSEVENT("Certificate has expired");
	break;
    case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
	RSEVENT("Error in notAfter field");
#ifdef UNUSED
	BIO_printf(bio_err,"notAfter=");
	ASN1_TIME_print(bio_err,X509_get_notAfter(cert));
	BIO_printf(bio_err,"\n");
#endif /* UNUSED */
	break;
    }

    RSINFO1("Verify return: %d", ok);
    return ok;
}

/*
 * Load a PEM certificate and its private key into `ctx` and check that the
 * two belong together.
 *
 * cert_file  PEM certificate file; when NULL nothing is loaded and the call
 *            succeeds.
 * key_file   PEM private-key file; when NULL the key is read from cert_file.
 *
 * Returns 1 on success and 0 on failure; the reason is logged.
 */
static int set_cert_stuff(SSL_CTX *ctx, char *cert_file, char *key_file)
{
    char *key_path;

    if (cert_file == NULL)
	return 1;

    if (SSL_CTX_use_certificate_file(ctx, cert_file, SSL_FILETYPE_PEM) <= 0)
    {
	RSEVENT1("Unable to get certificate from '%s'", cert_file);
	local_get_errors();
	return 0;
    }

    key_path = (key_file != NULL) ? key_file : cert_file;
    if (SSL_CTX_use_PrivateKey_file(ctx, key_path, SSL_FILETYPE_PEM) <= 0)
    {
	RSEVENT1("Unable to get private key from '%s'", key_path);
	local_get_errors();
	return 0;
    }

    /* Both halves are loaded; make sure the key matches the certificate. */
    if (!SSL_CTX_check_private_key(ctx))
    {
	RSEVENT("Private key does not match the certificate public key");
	return 0;
    }
    return 1;
}

/*
 * Debug callback that ssld_ssl_ctx() installs on a connection's read BIO
 * when rs_debug > RS_DEBUG_WARN.
 *
 * After each completed read or write (the BIO_CB_RETURN phase) it logs the
 * requested size, the BIO address and the result.  It then dumps (RSDUMP)
 * the bytes actually transferred.  A non-positive `ret` (for example -1
 * from an empty memory BIO) means nothing was transferred, so nothing is
 * dumped.
 *
 * On the *_RETURN calls OpenSSL uses the callback's return value as the
 * result of the operation, so `ret` is passed through unchanged.
 */
static long bio_dump_cb(BIO *bio, int cmd, char *argp, int argi, long argl, long ret)
{
    if (cmd == (BIO_CB_READ|BIO_CB_RETURN))
    {
	RSINFO3("Read %d bytes from %p: returned %ld", argi, (void *)bio, ret);
	if (ret > 0)
	{
	    RSDUMP(argp, ret);
	}
    }
    else if (cmd == (BIO_CB_WRITE|BIO_CB_RETURN))
    {
	RSINFO3("Write %d bytes to %p: returned %ld", argi, (void *)bio, ret);
	if (ret > 0)
	{
	    RSDUMP(argp, ret);
	}
    }
    return(ret);
}

/*
 * Info callback that ssld_ssl_init() installs on both contexts when
 * rs_debug > RS_DEBUG_WARN.
 *
 * It logs handshake state transitions, alerts, and handshake exits that
 * failed (ret == 0) or hit an error (ret < 0).  It does nothing unless
 * rs_debug is at least RS_DEBUG_EVENT.
 */
static void apps_ssl_info_callback(SSL *s, int where, int ret)
{
    char *side;
    int state_bits;

    if (rs_debug < RS_DEBUG_EVENT)
	return;

    state_bits = where & ~SSL_ST_MASK;
    side = (state_bits & SSL_ST_CONNECT) ? "SSL_connect"
	 : (state_bits & SSL_ST_ACCEPT) ? "SSL_accept"
	 : "undefined";

    if (where & SSL_CB_LOOP)
    {
	RSINFO2("STATE: %s: %s", side, SSL_state_string_long(s));
	return;
    }

    if (where & SSL_CB_ALERT)
    {
	char *direction = (where & SSL_CB_READ) ? "read" : "write";

	RSINFO3("STATE: SSL3 alert %s: %s: %s",
	    direction,
	    SSL_alert_type_string_long(ret),
	    SSL_alert_desc_string_long(ret));
	return;
    }

    /* Only unsuccessful handshake exits are of interest. */
    if (!(where & SSL_CB_EXIT) || ret > 0)
	return;

    if (ret == 0)
    {
	RSINFO2("STATE: %s: failed in %s", side, SSL_state_string_long(s));
    }
    else
    {
	RSINFO2("STATE: %s: error in %s", side, SSL_state_string_long(s));
    }
}

/*
 * Drain OpenSSL's error queue into the RS log.
 *
 * ERR_print_errors() formats every queued error into the memory BIO bio_err.
 * The text is then read back in chunks, NUL-terminated, and logged.
 */
static void
local_get_errors(void)
{
    char buffer[402];
    int len;

    ERR_print_errors(bio_err);
    /* Leave room for the terminating NUL. */
    while ((len = BIO_read(bio_err, buffer, (int)sizeof buffer - 1)) > 0)
    {
	buffer[len] = 0;
	/* OpenSSL error text can embed caller data (e.g. file names), so it
	 * is passed as an argument, never used as the format string. */
	RSINFO1("%s", buffer);
    }
}
