#! /usr/bin/perl -w

use warnings;
use strict;
use Socket qw(IPPROTO_TCP TCP_NODELAY);
use IO::Socket;
use IO::Poll;
use Time::HiRes qw(gettimeofday tv_interval);
use Net::SSLeay qw(ERROR_WANT_READ ERROR_WANT_WRITE);

# OpenSSL SSL_ctrl() command code, passed to Net::SSLeay::ctrl() to set SNI.
use constant SSL_CTRL_SET_TLSEXT_HOSTNAME => 55;

# Deterministic sorting and regexp behaviour.
# NOTE(review): Perl itself only honours the locale under "use locale",
# which this file does not enable; this mainly affects child processes.
$ENV{LANG} = "C";

# Avoid premature death when remote socket closes: a write to a closed
# peer then fails with an error instead of raising SIGPIPE.
$SIG{PIPE} = "IGNORE";

# Whole-operation deadlines, in seconds ("connect" is handed to
# IO::Socket::INET as-is; "tls" is converted to milliseconds per phase).
my %timeout = (connect => 30, tls => 30);

### --- One time global SSL initialization ---

Net::SSLeay::load_error_strings();
Net::SSLeay::SSLeay_add_ssl_algorithms();
Net::SSLeay::randomize();
my $sslctx = Net::SSLeay::CTX_new();
die sprintf "Error creating SSL context: %s\n", ssl_error() if (! $sslctx);

# No legacy SSL 2.0/3.0 protocols.
# Constants are called with explicit parens rather than the &-sigil form,
# which would implicitly pass along the current @_.
#
my $ssl_options = Net::SSLeay::OP_ALL();
$ssl_options |= Net::SSLeay::OP_NO_SSLv2();
$ssl_options |= Net::SSLeay::OP_NO_SSLv3();
Net::SSLeay::CTX_set_options($sslctx, $ssl_options);

# Season to taste.
# A cipher string that selects nothing makes CTX_set_cipher_list() fail;
# report that up front instead of failing obscurely at handshake time.
#
Net::SSLeay::CTX_set_cipher_list($sslctx, join(":",
    "DEFAULT",	# (older OpenSSL: ALL:+RC4:!aNULL:!eNULL:@STRENGTH)
    "!RC4",	# But not RC4
    "!EXPORT",	# At least medium
    "!LOW",	# -"-
    "!MD5",	# Rules out all SSLv2 ciphers
    "!DSS",	# Nobody should be using DSS/DSA by now
    "!SEED",	# Too exotic
    "!IDEA",	# -"-
    "!RC2",	# -"-
)) or die sprintf "Error setting SSL cipher list: %s\n", ssl_error();

### --- SSL I/O routines ---

# Wait until $conn->{sock} is ready for the IO::Poll event mask $ev,
# charging the time spent against the millisecond deadline in
# $conn->{timeout}.  Dies with "$why\n" when the deadline is already used
# up or expires before the socket becomes ready, and with "$why: $!\n" if
# poll() itself fails.
sub poll_wait {
    my ($conn, $ev, $why) = @_;

    die "$why\n" if $conn->{timeout} <= 0;

    my $p = IO::Poll->new();
    $p->mask($conn->{sock} => $ev);

    # IO::Poll::poll() takes its timeout in seconds (it multiplies by 1000
    # internally), while our deadline is kept in milliseconds.
    my $t0 = [gettimeofday];
    my $ready = $p->poll($conn->{timeout} / 1000);

    # Charge the full elapsed time; truncating to whole milliseconds would
    # let many short polls run without ever consuming the deadline.
    $conn->{timeout} -= 1000 * tv_interval($t0);

    die "$why: $!\n" if ($ready < 0);

    # events() returns 0 (never undef) when nothing happened, i.e. timeout.
    die "$why\n" if (! $p->events($conn->{sock}));
    return;
}

# Read whatever decrypted data is available on $conn->{ssl} and append it
# to $conn->{rbuf}.  Retries through ERROR_WANT_READ / ERROR_WANT_WRITE
# using poll_wait() (so the current $conn->{timeout} deadline applies).
# Dies on timeout, on an SSL error, or when no data is obtained (EOF).
sub sslread {
    my ($conn) = @_;
    my $ssl = $conn->{ssl};
    my $got;

    while (1) {
	($got, my $status) = Net::SSLeay::read($ssl);
	last if (defined($got));

	my $err = Net::SSLeay::get_error($ssl, $status ? $status : -1);
	last if ($err == 0);
	if ($err == ERROR_WANT_READ) {
	    poll_wait($conn, POLLIN, "SSL read timeout");
	} elsif ($err == ERROR_WANT_WRITE) {
	    # The TLS layer may need to write even while we are reading.
	    poll_wait($conn, POLLOUT, "SSL read timeout");
	} else {
	    die sprintf "SSL read error: %d: %s\n", $err, ssl_error();
	}
    }

    # An empty (or, defensively, undefined) result is treated as EOF;
    # checking definedness first avoids an uninitialized-value warning.
    die "Unexpected read EOF\n" if (! defined($got) || $got eq "");
    $conn->{rbuf} .= $got;
    return;
}

# Append more input to $conn->{rbuf}: via sslread() once TLS is active,
# otherwise by a plain sysread() of up to 4096 bytes after waiting for
# readability (subject to the $conn->{timeout} deadline).
# Dies on timeout, read error, or EOF.
sub bufread {
    my ($conn) = @_;

    return sslread($conn) if ($conn->{ssl});

    poll_wait($conn, POLLIN, "Read timeout");
    my $n = sysread($conn->{sock}, $conn->{rbuf}, 4096, length($conn->{rbuf}));

    # sysread() returns undef on error (never a negative count) and 0 at
    # EOF, so the error case must be tested first.
    die "Read error: $!\n" if (! defined($n));
    die "Unexpected read EOF\n" if ($n == 0);
    return;
}

# Write $conn->{wbuf} through the TLS session, retrying on
# ERROR_WANT_READ / ERROR_WANT_WRITE via poll_wait().  The buffer is left
# untouched so each retry passes the same data.  Returns the (positive)
# byte count reported by Net::SSLeay::write(); dies on timeout or error.
sub sslwrite {
    my ($conn) = @_;
    my $ssl = $conn->{ssl};

    while (1) {
	my $n = Net::SSLeay::write($ssl, $conn->{wbuf});
	return $n if (defined($n) && $n > 0);

	my $err = Net::SSLeay::get_error($ssl, $n ? $n : -1);
	if ($err == ERROR_WANT_READ) {
	    poll_wait($conn, POLLIN, "SSL write timeout");
	} elsif ($err == ERROR_WANT_WRITE) {
	    poll_wait($conn, POLLOUT, "SSL write timeout");
	} else {
	    die sprintf "SSL write error: %d: %s\n", $err, ssl_error();
	}
    }
}

# Send (part of) $conn->{wbuf} and drop the bytes that were written,
# leaving any remainder for the next call.  Uses sslwrite() once TLS is
# active, otherwise waits for writability and calls syswrite().
# Returns nothing; dies on timeout or write error.
sub bufwrite {
    my ($conn) = @_;

    # Nothing pending: skip the zero-length write, which the error check
    # below would otherwise report as a failure.
    return if (! length($conn->{wbuf}));

    my $n;
    if ($conn->{ssl}) {
	$n = sslwrite($conn);
    } else {
	poll_wait($conn, POLLOUT, "Write timeout");
	$n = syswrite($conn->{sock}, $conn->{wbuf}, length($conn->{wbuf}));
    }
    # syswrite() returns undef (not a negative count) on error.
    die "Write error: $!\n" if (! defined($n) || $n <= 0);

    # Flush written output
    #
    substr($conn->{wbuf}, 0, $n, "");
    return;
}

# Fetch (and remove) the next queued OpenSSL error code and return it as
# a human-readable string, for use in die messages.
sub ssl_error {
    my $code = Net::SSLeay::ERR_get_error();
    return Net::SSLeay::ERR_error_string($code);
}

# Drive the client-side TLS handshake on $ssl to completion.  Starts a
# fresh $timeout{tls} deadline (stored in milliseconds in $conn->{timeout})
# and waits with poll_wait() whenever OpenSSL wants to read or write.
# Returns nothing on success; dies if the peer shuts down, on timeout, or
# on any other SSL error.
sub sslconnect {
    my ($conn, $ssl) = @_;

    $conn->{timeout} = $timeout{tls} * 1000;

    while (1) {
	my $rc = Net::SSLeay::connect($ssl);
	return if ($rc > 0);
	die "SSL shutdown on connect\n" if ($rc == 0);

	my $why = Net::SSLeay::get_error($ssl, $rc);
	if ($why == ERROR_WANT_READ) {
	    poll_wait($conn, POLLIN, "SSL connect timeout");
	    next;
	}
	if ($why == ERROR_WANT_WRITE) {
	    poll_wait($conn, POLLOUT, "SSL connect timeout");
	    next;
	}
	die sprintf "SSL connect error: %d: %s\n", $why, ssl_error();
    }
}

# Detach the SSL handle from $conn, attempt an orderly TLS shutdown within
# a fresh $timeout{tls} deadline, and always free the handle.  Errors
# other than WANT_READ/WANT_WRITE end the attempt quietly (best-effort);
# a poll_wait() timeout is re-thrown after the handle has been freed.
sub sslshutdown {
    my ($conn) = @_;
    my $ssl = delete $conn->{ssl};
    my $once = 0;

    $conn->{timeout} = $timeout{tls} * 1000;

    # $ssl is no longer reachable via $conn, so trap any die from
    # poll_wait() here to avoid leaking the handle.
    my $ok = eval {
	LOOP: {
	    my $status = Net::SSLeay::shutdown($ssl);
	    last LOOP if ($status == 1);
	    # 0: our close_notify was sent but the peer's has not been seen
	    # yet (per SSL_shutdown() semantics); try exactly once more.
	    redo LOOP if ($status == 0 && ++$once == 1);
	    my $err = Net::SSLeay::get_error($ssl, $status);
	    if ($err == ERROR_WANT_READ) {
		poll_wait($conn, POLLIN, "SSL shutdown timeout");
		redo LOOP;
	    } elsif ($err == ERROR_WANT_WRITE) {
		poll_wait($conn, POLLOUT, "SSL shutdown timeout");
		redo LOOP;
	    }
	    # Any other outcome: give up on the shutdown handshake.
	}
	1;
    };
    my $failure = $@;
    Net::SSLeay::free($ssl);
    die($failure || "SSL shutdown failed\n") if (! $ok);
    return;
}

# Map a numeric SSL/TLS protocol version (as returned by
# Net::SSLeay::version) to a display name such as "TLSv1.2".
# Unrecognised values are rendered as "unknown(xxxx)" in hex.
sub tls_version {
    my ($version) = @_;

    # Versions that have a fixed name, compared numerically.
    for my $known ([0x0002, "SSLv2"], [0x0300, "SSLv3"], [0x0301, "TLSv1"]) {
	my ($code, $name) = @$known;
	return $name if ($version == $code);
    }

    # Anything above TLS 1.0 within the 0x03xx block is a TLS 1.x minor
    # version: 0x0302 -> TLSv1.1, 0x0303 -> TLSv1.2, and so on.
    my $in_tls1_range = ($version > 0x0301 && $version <= 0x03FF);
    return $in_tls1_range
	? sprintf("TLSv1.%d", $version - 0x0301)
	: sprintf("unknown(%04x)", $version);
}

# Create an SSL handle on $conn->{sock}, optionally request SNI for $sni,
# and complete the handshake via sslconnect().  On success stores the
# handle in $conn->{ssl} and fills $conn->{sslinfo} with the protocol
# version, cipher, cipher bits and the PEM-encoded peer certificate chain.
# On failure the handle is freed and the error re-thrown.
sub dossl {
    my ($conn, $sni) = @_;
    my $ssl = Net::SSLeay::new($sslctx);
    die sprintf "Error creating SSL handle: %s\n", ssl_error() if (! $ssl);

    # Until the handshake succeeds the handle is not stored in $conn, so
    # donessl() cannot release it; free it here on any failure.
    my $ok = eval {
	Net::SSLeay::set_fd($ssl, fileno($conn->{sock}));

	# Request SNI via the raw SSL_ctrl() command; larg 0 is
	# TLSEXT_NAMETYPE_host_name.  Skipped when no name was supplied.
	# XXX: Gross, but best we can do for now given the API
	# NOTE(review): newer Net::SSLeay releases may provide
	# set_tlsext_host_name(); consider it if the minimum version allows.
	#
	Net::SSLeay::ctrl($ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, 0, $sni)
	    if (defined($sni) && length($sni));

	sslconnect($conn, $ssl);
	1;
    };
    if (! $ok) {
	my $err = $@;
	Net::SSLeay::free($ssl);
	die($err || "SSL handshake failed\n");
    }

    $conn->{ssl} = $ssl;
    my $sslinfo = $conn->{sslinfo} = {};

    $sslinfo->{version} = tls_version(Net::SSLeay::version($ssl));
    $sslinfo->{cipher} = Net::SSLeay::get_cipher($ssl);
    $sslinfo->{bits} = Net::SSLeay::get_cipher_bits($ssl);

    # This does not increment the chain cert reference counts,
    # so we encode to PEM before closing the SSL connection.
    #
    $sslinfo->{chain} = [ map { Net::SSLeay::PEM_get_string_X509($_) }
	Net::SSLeay::get_peer_cert_chain($ssl) ];
    return;
}

# Tear down $conn: shut down TLS if active, then close and forget the
# socket.  The socket is closed even when the TLS shutdown dies (that
# error is re-thrown afterwards), and calling this again is a no-op.
sub donessl {
    my ($conn) = @_;

    my $ok = eval { sslshutdown($conn) if ($conn->{ssl}); 1 };
    my $err = $@;

    my $sock = delete $conn->{sock};
    $sock->close() if ($sock);

    die($err || "SSL shutdown failed\n") if (! $ok);
    return;
}

# Connect to $host:$port over TCP (within $timeout{connect} seconds), run
# the TLS handshake with optional SNI name $sni, tear the connection down,
# and return the $conn->{sslinfo} hashref populated by dossl().
# Dies with a descriptive message on connection or TLS failure; the
# connection is torn down in either case.
sub sslhost {
    my ($host, $port, $sni) = @_;
    my $sock = IO::Socket::INET->new(
	PeerAddr  => $host,
	PeerPort  => $port,
	Proto     => "tcp",
	Timeout   => $timeout{connect},
	Blocking  => 0,
    );
    # IO::Socket::INET leaves its diagnostic (e.g. a bad hostname) in $@;
    # $! alone can be empty or unrelated for resolver failures.
    die sprintf("Connection failed: %s\n", $@ || $!) if (! defined($sock));
    $sock->autoflush(1);
    setsockopt($sock, SOL_SOCKET, SO_KEEPALIVE, 1);
    setsockopt($sock, IPPROTO_TCP, TCP_NODELAY, 1);

    my $conn = {
	sock => $sock,
	rbuf => "",
	wbuf => "",
    };

    # Detect failure via the eval's return value rather than the truth of
    # $@, which is not a reliable success indicator.
    my $ok = eval {
	#
	# XXX: Insert pre-SSL handshake here
	# Use bufread() and bufwrite().
	# Reset the timeout after each logical operation (poll_wait() reads
	# a millisecond deadline from $conn->{timeout}).
	#
	dossl($conn, $sni);
	#
	# Insert Post-SSL interaction here
	# Reset the timeout after each logical operation
	#
	1;
    };
    my $err = $@;
    donessl($conn);
    die($err || "SSL session failed\n") if (! $ok);
    return $conn->{sslinfo};
}

### --- Main

# Usage: script host [port]   (port defaults to 443; host is also the SNI name)
my ($host, $port) = @ARGV;
die "Usage: $0 host [port]\n" if (! defined($host) || $host eq "");
$port = 443 if (! defined($port) || $port eq "");
my $sslinfo = sslhost($host, $port, $host);
printf ";; SSL: protocol = %s, cipher = %s (%d bits)\n",
    $sslinfo->{version}, $sslinfo->{cipher}, $sslinfo->{bits};
print join("\n", @{$sslinfo->{chain}});
