#!/usr/bin/perl

package Proxy;

use strict;
use warnings;
use IO::Socket::INET;
use IO::Select;
use Carp;

# Constructor. $options is an optional hashref; each recognised key that is
# missing or undef falls back to the value in %fallback below. Runtime state
# (select set, socket maps, flags, counters) is always freshly initialised.
sub new {
	my ($class, $options) = @_;

	# Caller-overridable settings and the value used when not supplied.
	my %fallback = (
		src_host            => '0.0.0.0',
		dst_host            => '0.0.0.0',
		src_port            => 11111,
		dst_port            => 5432,
		debug               => 0,
		allowed_connections => 1,
	);

	my $self = {
		ioset      => IO::Select->new,    # every socket run() waits on
		sockets    => {},                 # socket => the socket it is paired with
		names      => {},                 # socket => label used in debug output
		running    => 1,                  # run() keeps looping while true
		drop       => 0,                  # when true, data read is not forwarded
		conn_count => 0,                  # client/remote pairs opened and not yet closed
	};
	for my $key (keys %fallback) {
		$self->{$key} = $options->{$key} // $fallback{$key};
	}

	return bless $self, $class;
}

# Open a TCP connection to $host:$port and return the socket.
# Croaks with the target and $! when the connection cannot be made.
sub _new_conn {
	my (undef, $host, $port) = @_;    # the invocant is not used

	my %dial = (
		PeerAddr => $host,
		PeerPort => $port,
		#Blocking  => 0,
	);
	if (my $peer = IO::Socket::INET->new(%dial)) {
		return $peer;
	}
	croak "Unable to connect to $host:$port: $!";
}

# Create the listening socket the proxy accepts clients on and return it.
# Croaks with the address and $! when it cannot be created.
sub _new_server {
	my (undef, $host, $port) = @_;    # the invocant is not used

	my %listen_opts = (
		LocalAddr => $host,
		LocalPort => $port,
		ReuseAddr => 1,
		Listen    => 10,    # accept-queue backlog
		#Blocking  => 0,
	);
	if (my $listener = IO::Socket::INET->new(%listen_opts)) {
		return $listener;
	}
	croak "Unable to listen on $host:$port: $!";
}

# Accept one pending client on $server and pair it with a fresh upstream
# connection to $remote_host:$remote_port. Both sockets are added to the
# select set, mapped to each other in {sockets} and given debug labels in
# {names}. Returns nothing. If the upstream connection fails, the accepted
# client is closed and the error is re-thrown.
sub _new_connection {
	my ($self, $server, $remote_host, $remote_port) = @_;

	# accept() returns false on failure (e.g. interrupted by a signal);
	# there is then no client to pair up.
	my $client = $server->accept
		or return;

	# Dial upstream. If that croaks, release the already-accepted client
	# before re-throwing so its descriptor does not leak.
	my $remote;
	unless (eval { $remote = $self->_new_conn($remote_host, $remote_port); 1 }) {
		my $err = $@;
		$client->close;
		die $err;
	}

	$self->{ioset}->add($client);
	$self->{ioset}->add($remote);

	# Each side maps to its peer so run() can forward in both directions.
	$self->{sockets}->{$client} = $remote;
	$self->{sockets}->{$remote} = $client;

	# conn_count is the number of open pairs and goes down again on close,
	# so it cannot double as an id. A separate, ever-increasing sequence
	# number keeps the debug labels unique.
	$self->{conn_count}++;
	my $seq = ++$self->{conn_seq};
	$self->{names}->{$client} = $self->{src_port}.'-'.$seq;
	$self->{names}->{$remote} = $remote_port.'-'.$seq;
	print "  new connection $self->{names}->{$client}\n" if $self->{debug};
	return;
}

# Close $client and, when it is half of a tracked pair, its peer as well.
# The handles are removed from the select set and from the {sockets} and
# {names} maps. Only closing a tracked pair frees a connection slot:
# closing an untracked handle (run() passes the listening socket here)
# leaves conn_count alone.
sub _close_connection {
	my ($self, $client) = @_;
	# Peer of $client; undef when $client is not part of a tracked pair.
	my $remote = $self->{sockets}->{$client};

	if ($self->{debug}) {
		print "  closing connection ".($self->{names}->{$client} // '')."\n";
	}

	# Take the handles out of the select set before closing them, while
	# they still have a file descriptor for IO::Select to look up.
	$self->{ioset}->remove($client);
	$self->{ioset}->remove($remote) if defined $remote;

	delete $self->{sockets}->{$client};
	delete $self->{names}->{$client};
	if (defined $remote) {
		delete $self->{sockets}->{$remote};
		delete $self->{names}->{$remote};
	}

	$client->close;
	return unless defined $remote;
	$remote->close;

	# A real pair went away: free its slot, never going below zero.
	$self->{conn_count}-- if $self->{conn_count} > 0;
	return;
}

# Main loop: listen on src_host:src_port and relay bytes between each
# accepted client and its dst_host:dst_port connection until {running} is
# cleared (the main script clears it from a TERM handler). At most
# {allowed_connections} pairs are served at once. Extra clients stay queued
# in the listen backlog until a slot frees up. When {drop} is set, data is
# read (and logged in debug mode) but not forwarded. On exit every
# remaining pair and the listener are closed.
sub run {
	my $self = shift;

	# Writing to a peer that already hung up raises SIGPIPE, whose default
	# action would kill the whole proxy. Ignore it and let syswrite report
	# the failure instead (handled below).
	local $SIG{PIPE} = 'IGNORE';

	my $server = $self->_new_server(
		$self->{src_host},
		$self->{src_port}
	);
	$self->{ioset}->add($server);

	# Upper bound (seconds) on a single can_read() wait. This way a TERM
	# that lands just before select() starts is still noticed promptly.
	my $poll_interval = 1;

	while ($self->{running}) {
		# A client waiting in the backlog keeps the listener readable. While
		# at capacity (and with open pairs to wait on), stop watching the
		# listener; otherwise can_read() would return it immediately every
		# time and this loop would spin. The queued client is accepted once
		# a pair closes and the listener is watched again.
		if ($self->{conn_count} >= $self->{allowed_connections}
			&& $self->{conn_count} > 0) {
			$self->{ioset}->remove($server);
		}
		else {
			$self->{ioset}->add($server);
		}

		SOCKET: for my $socket ($self->{ioset}->can_read($poll_interval)) {
			if ($socket == $server) {
				if ($self->{conn_count} < $self->{allowed_connections}) {
					$self->_new_connection(
						$server,
						$self->{dst_host},
						$self->{dst_port}
					);
				}
				next SOCKET;
			}

			# Skip sockets whose pair was already closed earlier in this batch.
			next SOCKET unless exists $self->{sockets}->{$socket};
			my $remote = $self->{sockets}->{$socket};

			my $buffer;
			my $read = $socket->sysread($buffer, 4096);
			if (!$read) {
				# 0 = orderly EOF, undef = read error: tear the pair down.
				$self->_close_connection($socket);
				next SOCKET;
			}

			print '  '.($self->{drop} ? 'not ': '')."forwarding $read bytes from $self->{names}->{$socket} to $self->{names}->{$remote}\n" if $self->{debug};
			hexdump($buffer) if $self->{debug} > 1;
			next SOCKET if $self->{drop};

			# syswrite may take only part of the buffer; keep writing until
			# all of it has been handed to the peer.
			my $offset = 0;
			WRITE: while ($offset < $read) {
				my $written = $remote->syswrite($buffer, $read - $offset, $offset);
				if (!defined $written) {
					next WRITE if $!{EINTR};    # interrupted by a signal: retry
					$self->_close_connection($socket);    # peer is gone
					next SOCKET;
				}
				$offset += $written;
			}
		}
	}

	# Shutting down: close every pair that is still open, then the listener.
	my @still_open = values %{ $self->{sockets} };
	for my $sock (@still_open) {
		# Closing one half removes its partner too, so skip it if gone.
		$self->_close_connection($sock) if exists $self->{sockets}->{$sock};
	}
	$self->{ioset}->remove($server);
	$server->close;
	return;
}

# Print $_[0] to the currently selected output handle as a classic hex dump.
# Each 16-byte row shows the offset (hex and decimal), the bytes in hex in
# groups of four, and the printable-ASCII rendering with '.' for the rest.
sub hexdump {
	my ($data) = @_;

	# 16-byte rows; the last one may be shorter.
	my @rows = unpack '(a16)*', $data;

	for my $row_no (0 .. $#rows) {
		my $bytes = $rows[$row_no];
		my $pos   = 16 * $row_no;

		# Hex digits in space-separated groups of 8 (4 bytes each); the
		# %-36s field below pads every row to the same width.
		my $hex = join ' ', unpack '(a8)*', unpack('H*', $bytes);

		# Printable ASCII is kept, everything else becomes '.'.
		(my $shown = $bytes) =~ tr/ -~/./c;

		printf "0x%08x (%05u)  %-*s %s\n", $pos, $pos, 36, $hex, $shown;
	}
}

package main;

use strict;
use warnings;
use DBI;
use DBD::Pg;
use POSIX qw(:signal_h);
use Time::HiRes qw(ualarm time usleep);
use Getopt::Long qw(:config bundling);

# Timings for the reproduction, in seconds.
my $timeout       = 6;    # limit armed by timeout_wrap() around the query
my $pg_sleep      = 3;    # how long the SELECT ... pg_sleep() query runs
my $kill_interval = 1;    # delay before the helper child sends its signals

# Command-line switches; GetOptions overwrites these defaults in place.
my $options = {
	allowed_connections => 2,
	cancel              => 0,
	drop                => 0,
	verbose             => 0,
};
my @option_spec = (
	'drop|d!',
	'cancel|c!',
	'verbose|v+',
	'allowed_connections|a=i',
);
GetOptions($options, @option_spec)
	or die 'usage: pg_cancel_bug.pl [-d] [-c] [-v] [-a CONNECTIONS]';

# Child #1 runs the proxy. It is steered from outside with signals:
# USR1 makes it stop forwarding data, TERM makes its main loop exit.
my $proxy_child = fork();
die "Can't fork proxy: $!\n" unless defined $proxy_child;

if (!$proxy_child) {
	my %proxy_args = (
		allowed_connections => $options->{allowed_connections},
		debug               => $options->{verbose},
	);
	my $proxy = Proxy->new(\%proxy_args);

	$SIG{USR1} = sub { $proxy->{drop} = 1; };
	$SIG{TERM} = sub { $proxy->{running} = 0; };

	$proxy->run;
	exit;
}

# Child #2 is the "naughty" one. After $kill_interval seconds it switches
# the proxy into drop mode (-d); 50 ms later it sends SIGINT to the parent
# (-c). Meanwhile the parent runs the slow query through the proxy on port
# 11111 and finally tells the proxy to terminate.
my $naughty_child = fork();
die "Can't fork naughty child: $!\n" unless defined $naughty_child;

if (!$naughty_child) {
	usleep($kill_interval * 1_000_000);
	kill 'USR1', $proxy_child if $options->{drop};

	usleep(50_000);
	kill 'INT', getppid() if $options->{cancel};
}
else {
	# NOTE(review): nothing waits for the proxy child to reach listen()
	# before the connect below, and neither child is reaped with waitpid.
	my ($dbh, $sth);
	my $ok = eval {
		$dbh = DBI->connect(
			'dbi:Pg:dbname=postgres;port=11111;host=localhost',
			'postgres', undef,
			{ RaiseError => 1, AutoCommit => 1 },
		) or die 'connect: '.$DBI::errstr;

		$sth = $dbh->prepare(qq{SELECT 'ok', pg_sleep($pg_sleep)})
			or die 'prepare: '.$DBI::errstr;

		my $run_query = sub {
			$sth->execute() or die 'execute: '.$DBI::errstr;
			my $rows = $sth->fetchall_arrayref() or die 'fetch: '.$DBI::errstr;
			return $rows->[0][0];
		};
		my $answer = timeout_wrap($timeout, $sth, $run_query);
		print $answer."\n" if $answer;
		1;
	};
	warn $@ unless $ok;

	# Done either way: shut the proxy down.
	kill 'TERM', $proxy_child;
}

# Run $callback (in scalar context) with a time limit of $timeout seconds.
# Fractions are allowed because the timer is armed with Time::HiRes::ualarm.
# While it runs, SIGINT is turned into a statement cancel on $sth.
#
# Returns whatever $callback returns.
# Dies with a message starting with 'timeout' when the timer fires.
# Calls $sth->cancel and dies with 'interrupt' when SIGINT arrives.
# Re-throws any other error from $callback unchanged.
# The previous SIGINT and SIGALRM dispositions are restored before returning
# or dying.
sub timeout_wrap {
	my ($timeout, $sth, $callback) = @_;

	# The handlers go in through POSIX::sigaction rather than %SIG.
	# Presumably this is so they fire immediately even while blocked inside
	# the database driver, where Perl's deferred signals would wait.
	my $alrm_action = POSIX::SigAction->new(
		sub { die { error => 'timeout' } },
		POSIX::SigSet->new(SIGALRM),
	);
	my $int_action = POSIX::SigAction->new(
		sub { die { error => 'interrupt' } },
		POSIX::SigSet->new(SIGINT),
	);

	# Save each previous disposition in its own object. Sharing one object
	# would let the second sigaction() overwrite the first saved action.
	my $old_int  = POSIX::SigAction->new;
	my $old_alrm = POSIX::SigAction->new;
	sigaction(SIGINT,  $int_action,  $old_int);
	sigaction(SIGALRM, $alrm_action, $old_alrm);

	my $result;
	# Double eval: the inner one runs the callback under the timer. The outer
	# one also catches a timer that fires after the callback returned but
	# before the timer was cancelled.
	eval {
		eval {
			ualarm($timeout * 1_000_000);
			$result = $callback->();
		};
		ualarm(0);    # cancel the timer armed above
		die $@ if $@;
	};
	my $err = $@;    # capture before anything else can reset it

	sigaction(SIGINT,  $old_int);
	sigaction(SIGALRM, $old_alrm);

	if (ref $err eq 'HASH') {
		if ($err->{error} eq 'interrupt') {
			print "cancel!\n";
			$sth->cancel();
		}
		die $err->{error};
	}
	die $err if $err;
	return $result;
}
