/*
 * Copyright (C) 2009 Robin Seggelmann, seggelmann@fh-muenster.de,
 *                    Michael Tuexen, tuexen@fh-muenster.de
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the project nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE PROJECT AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE PROJECT OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#ifdef WIN32
#include <winsock2.h>
#elif SOLARIS2 || SOLARIS8 || AIX || LINUX


#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
#include <time.h>
#include <sys/time.h>
#include <sys/select.h>
#include <fcntl.h>

#endif

#ifdef __linux__
#include <getopt.h>
#endif

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/rand.h>

/* Length in bytes of the random HMAC key used to compute DTLS cookies. */
#define COOKIE_SECRET_LENGTH 16
/* Size of the stack buffers used for SSL_read() data and OpenSSL error text. */
#define BUFFER_SIZE          (1<<16)
/* Number of slots in each connection table (acceptfd_set / readfd_set). */
#define MAX_CONNECTIONS 5

/* HMAC key for cookies; filled lazily with RAND_bytes() by generate_cookie(). */
unsigned char cookie_secret[COOKIE_SECRET_LENGTH];
/* Set to 1 once cookie_secret has been filled; verify_cookie() rejects all
 * cookies until then. */
int cookie_initialized=0;

/* Logging switch; gates per-message and connection-closed prints in connection_read(). */
int verbose = 1;
/* NOTE(review): not read anywhere in the visible code. */
int veryverbose = 1;

/* Master select() read sets. accept_rfds holds sockets whose SSL_accept()
 * has not completed yet (polled by connection_accept()); read_rfds holds
 * sockets with a completed handshake (polled by connection_read()). */
fd_set accept_rfds;
fd_set read_rfds;

/*
 * One slot of a connection table.
 *   fd    - connected UDP socket, or -1 when the slot is free
 *   token - the SSL* associated with fd, stored as void*
 *   count - per-connection message counter printed by connection_read()
 */
struct fdinfo
{
  int fd;
  void *token;
  int count;
};

/* Connections still completing the DTLS handshake. */
struct fdinfo acceptfd_set[MAX_CONNECTIONS];
/* Connections whose handshake has completed and that are read with SSL_read(). */
struct fdinfo readfd_set[MAX_CONNECTIONS];

/*
 * Command-line help text.
 * NOTE(review): nothing in the visible code prints this, and main() does not
 * currently parse any of these options -- confirm before relying on it.
 */
char Usage[] =
"Usage: dtls_udp_echo [options] [address]\n"
"Options:\n"
"        -l      message length (Default: 100 Bytes)\n"
"        -p      port (Default: 23232)\n"
"        -n      number of messages to send (Default: 5)\n"
"        -v      verbose\n"
"        -V      very verbose\n";



/*
 * Hand-off record from start_server() to connection_handle() for a client
 * that has passed DTLSv1_listen():
 *   server_addr - local address the per-client socket is bound to
 *   client_addr - client address the per-client socket is connect()ed to
 *   ssl         - the SSL object that completed the cookie exchange
 */
struct pass_info {
	struct sockaddr_in server_addr;
	struct sockaddr_in client_addr;
	SSL *ssl;
};


/*
 * Locate the slot holding the numerically largest descriptor in a
 * connection table of MAX_CONNECTIONS entries.
 *
 * Returns that slot's index (the lowest index on ties), or -1 when no slot
 * holds a descriptor (every fd is -1).
 */
int getMaxfd(struct fdinfo *set)
{
  int best = 0;
  int slot;

  for (slot = 1; slot < MAX_CONNECTIONS; slot++)
    best = (set[slot].fd > set[best].fd) ? slot : best;

  return (set[best].fd == -1) ? -1 : best;
}

void clearFDSet(struct fdinfo *set, int fd)
{
  int i;

  for (i = 0; i < MAX_CONNECTIONS; i++)
  {
    if (set[i].fd == fd) 
    {
      set[i].fd = -1;
      set[i].token = NULL;
      set[i].count = 0;
    }
    
  }
}

/*
 * Record (fd, token) in the first free slot of `set`, resetting that slot's
 * counter to zero.
 *
 * There is no error return: when every slot is occupied the process prints
 * a message and terminates with exit(-1).
 */
void setFDSet(struct fdinfo *set, int fd, void *token)
{
  int slot = 0;

  while (slot < MAX_CONNECTIONS && set[slot].fd != -1)
    slot++;

  if (slot == MAX_CONNECTIONS)
  {
    printf("FATAL number of connection exceeded \n");
    exit(-1);
  }

  set[slot].fd = fd;
  set[slot].token = token;
  set[slot].count = 0;
}

/*
 * Peer-certificate verification hook installed with SSL_CTX_set_verify().
 *
 * It ignores both OpenSSL's pre-verification result (`ok`) and the store
 * context, and accepts every certificate. This is only suitable for testing:
 * a real deployment should honour `ok` and/or inspect `ctx`.
 */
int dtls_verify_callback (int ok, X509_STORE_CTX *ctx) {
	(void) ok;
	(void) ctx;

	return 1;
}


/*
 * Compute the stateless DTLS cookie for the current peer of `ssl`:
 * HMAC-SHA1 keyed with cookie_secret over the peer's port followed by its
 * IPv4 address, both copied exactly as stored in the sockaddr_in.
 *
 * `out` must have room for EVP_MAX_MD_SIZE bytes; the digest length is
 * written to *out_len. Returns 1 on success, 0 if HMAC() fails.
 * Only AF_INET peers are supported; any other family trips OPENSSL_assert().
 */
static int compute_peer_cookie(SSL *ssl, unsigned char *out, unsigned int *out_len)
{
	/* BIO_dgram_get_peer() copies an address whose size depends on the
	 * peer's family, so receive it into storage large enough for any
	 * family rather than a bare sockaddr_in. */
	union {
		struct sockaddr_storage ss;
		struct sockaddr_in s4;
	} peer;
	unsigned char buffer[sizeof(peer.s4.sin_port) + sizeof(peer.s4.sin_addr)];

	/* Zeroed so a failed lookup yields family 0 instead of stack garbage. */
	memset(&peer, 0, sizeof(peer));
	(void) BIO_dgram_get_peer(SSL_get_rbio(ssl), &peer);

	if (peer.ss.ss_family != AF_INET) {
		OPENSSL_assert(0);
		return 0;
	}

	memcpy(buffer, &peer.s4.sin_port, sizeof(peer.s4.sin_port));
	memcpy(buffer + sizeof(peer.s4.sin_port), &peer.s4.sin_addr,
	       sizeof(peer.s4.sin_addr));

	/* HMAC() returns NULL on failure and then leaves *out_len unset. */
	if (HMAC(EVP_sha1(), (const void *) cookie_secret, COOKIE_SECRET_LENGTH,
	         buffer, sizeof(buffer), out, out_len) == NULL)
		return 0;

	return 1;
}

/*
 * Cookie-generation callback (SSL_CTX_set_cookie_generate_cb).
 *
 * On first use, fills cookie_secret with random bytes. Then writes the
 * peer's cookie to `cookie` and its length to *cookie_len.
 * Returns 1 on success, 0 on failure (no cookie produced).
 */
int generate_cookie(SSL *ssl, unsigned char *cookie, unsigned int *cookie_len)
{
	unsigned char result[EVP_MAX_MD_SIZE];
	unsigned int resultlength = 0;

	/* Initialize a random secret */
	if (!cookie_initialized) {
		if (!RAND_bytes(cookie_secret, COOKIE_SECRET_LENGTH)) {
			printf("error setting random cookie secret\n");
			return 0;
		}
		cookie_initialized = 1;
	}

	if (!compute_peer_cookie(ssl, result, &resultlength))
		return 0;

	memcpy(cookie, result, resultlength);
	*cookie_len = resultlength;

	return 1;
}

/*
 * Cookie-verification callback (SSL_CTX_set_cookie_verify_cb).
 *
 * Recomputes the expected cookie for the current peer and compares it with
 * the `cookie_len` bytes at `cookie`. Returns 1 if they match, 0 otherwise,
 * including when no secret has been generated yet or HMAC() fails.
 */
int verify_cookie(SSL *ssl, unsigned char *cookie, unsigned int cookie_len)
{
	unsigned char result[EVP_MAX_MD_SIZE];
	unsigned int resultlength = 0;
	unsigned int i;
	unsigned char diff = 0;

	/* If secret isn't initialized yet, the cookie can't be valid */
	if (!cookie_initialized)
		return 0;

	if (!compute_peer_cookie(ssl, result, &resultlength))
		return 0;

	if (cookie_len != resultlength)
		return 0;

	/* Constant-time comparison: the timing must not reveal how many leading
	 * bytes of a forged cookie were correct. */
	for (i = 0; i < resultlength; i++)
		diff |= (unsigned char) (result[i] ^ cookie[i]);

	return diff == 0;
}





/*
 * Put socket `fd` into non-blocking mode.
 *
 * Platform selection uses the same macros as NetworkStart()/NetworkTerm():
 * WIN32, or one of SOLARIS2 / SOLARIS8 / AIX / LINUX; anything else fails
 * the build.
 *
 * Returns 0 on success. On failure it prints a message and returns -1.
 */
int set_non_blocking(int fd)
{
  int result;

#if WIN32
  u_long mode = 1;   /* ioctlsocket() takes a u_long*, not an int* */

  result = ioctlsocket(fd, FIONBIO, &mode);

#elif SOLARIS2 || SOLARIS8 || AIX || LINUX

#if defined(O_NONBLOCK)
  /* POSIX way: add O_NONBLOCK to the existing file status flags. */
  /* Fixme: O_NONBLOCK is defined but broken on SunOS 4.1.x and AIX 3.2.5. */
  int flags = fcntl(fd, F_GETFL, 0);

  if (flags == -1)
    flags = 0;
  result = fcntl(fd, F_SETFL, flags | O_NONBLOCK);
#else
  /* Legacy fallback for systems without O_NONBLOCK. */
  int on = 1;

  result = ioctl(fd, FIONBIO, &on);
#endif

#else
#error platform not defined!
#endif

  /* fcntl()/ioctl() signal failure with -1, and ioctlsocket() with
   * SOCKET_ERROR (-1). POSIX only promises "a value other than -1" from a
   * successful F_SETFL, so any other value counts as success. */
  if (result == -1)
    {
      printf("%s", "Failed to set socket non-blocking.\n");
      return (-1);
    }

  return (0);
}


/*
 * One-time socket-layer initialisation.
 *
 * Windows: calls WSAStartup() requesting Winsock 2.0 and returns -1 if it
 * reports an error, 0 otherwise.
 * Unix (SOLARIS2 / SOLARIS8 / AIX / LINUX): nothing to do; returns 0.
 * Any other platform: compilation is forced to fail.
 */
int NetworkStart(void)
{
#if WIN32
    WSADATA wsa_data;

    return (WSAStartup(MAKEWORD(2, 0), &wsa_data) != 0) ? -1 : 0;

#elif SOLARIS2 || SOLARIS8 || AIX || LINUX
    return 0;

#else
#error platform not defined!
#endif
}

/*
 * Socket-layer teardown, the counterpart of NetworkStart().
 *
 * Windows: calls WSACleanup() and returns -1 if it reports an error,
 * 0 otherwise.
 * Unix (SOLARIS2 / SOLARIS8 / AIX / LINUX): nothing to do; returns 0.
 * Any other platform: compilation is forced to fail.
 */
int NetworkTerm(void)
{
#if WIN32
    return (WSACleanup() != 0) ? -1 : 0;

#elif SOLARIS2 || SOLARIS8 || AIX || LINUX
    return 0;

#else
#error platform not defined!
#endif
}
int connection_handle(void *info) 
{


	struct pass_info *pinfo = (struct pass_info*) info;
	SSL *ssl = pinfo->ssl;
	int fd, reading = 0, ret;
	const int on = 1;
        int count = 0;
	
	struct timeval tv;
	int retval;
        int accept = 0;

	OPENSSL_assert(pinfo->client_addr.sin_family == pinfo->server_addr.sin_family);
	fd = socket(pinfo->client_addr.sin_family, SOCK_DGRAM, 0);
	if (fd < 0) {
		perror("socket");
		exit(-1);
	}
	
	set_non_blocking(fd);
	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (const void*) &on,  sizeof(on));
	switch (pinfo->client_addr.sin_family) {
		case AF_INET:
			bind(fd, (const struct sockaddr *) &pinfo->server_addr, sizeof(struct sockaddr_in));
			connect(fd, (struct sockaddr *) &pinfo->client_addr, sizeof(struct sockaddr_in));
			break;
		default:
			OPENSSL_assert(0);
			break;
	}

	/* Set new fd and set BIO to connected */
	BIO_set_fd(SSL_get_rbio(ssl), fd, BIO_NOCLOSE);
	BIO_ctrl(SSL_get_rbio(ssl), BIO_CTRL_DGRAM_SET_CONNECTED, 0, &pinfo->client_addr);

	if (SSL_accept(ssl) < 0)
	{    
	  char buf[BUFFER_SIZE];
	  perror("SSL_accept");
	  printf("%s\n", ERR_error_string(ERR_get_error(), buf));	  
	  FD_SET(fd, &accept_rfds);
	  setFDSet(acceptfd_set, fd, pinfo->ssl);
	}
	else
	{
	  printf("SSL accept successul first time\n");
	  FD_SET(fd, &read_rfds);
	  setFDSet(readfd_set, fd, pinfo->ssl);
	}

	return 0;
		
}


void connection_accept()
{
  struct timeval tv;	
  int i, retval;
  int maxfd_i, maxfd;
  SSL *ssl;
  fd_set temp_set;
  char buf[BUFFER_SIZE];
  
  tv.tv_sec = 1;
  tv.tv_usec = 0;
  maxfd_i =  getMaxfd(acceptfd_set);              		
 
  if (maxfd_i == -1)
    return;

  maxfd = acceptfd_set[maxfd_i].fd;
  memcpy(&temp_set, &accept_rfds, sizeof(fd_set));

  retval = select(maxfd+1, &temp_set, NULL, NULL, &tv);
  /* Don't rely on the value of tv now! */
 
  if (retval == -1) 
  {
    perror("C...select()");
    return;	
  }
  else if (retval)
  {
    printf("accept activity.\n");
  }
  /* FD_ISSET(0, &rfds) will be true. */
  else
  {
    printf("C...No data within 1 second.\n");
    return;
  }
 
  for (i = 0; i < MAX_CONNECTIONS; i++)
  {
    int fd;
    if (acceptfd_set[i].fd == -1)
      continue;
    if (!FD_ISSET(acceptfd_set[i].fd, &temp_set))
      continue;
    fd =  acceptfd_set[i].fd;
    ssl = (SSL*)acceptfd_set[i].token;

    /* Finish handshake */
    retval = SSL_accept(ssl); 
   
    if (retval == 0)
    {
      printf("SSL accept error \n");
      printf("%s\n", ERR_error_string(ERR_get_error(), buf));
      return;
    }
   
   if (retval < 0) 
   {
     perror("SSL_accept");
     printf("%s\n", ERR_error_string(ERR_get_error(), buf));
     return;
   }
   
   FD_CLR(fd, &accept_rfds);
   clearFDSet(acceptfd_set, fd);
   FD_SET(fd, &read_rfds);
   setFDSet(readfd_set, fd, ssl);
  }

            	
}


/*
 * Service connections whose DTLS handshake has completed.
 *
 * Waits up to one second for any socket in read_rfds to become readable.
 * Then performs one SSL_read() on every readable connection in readfd_set:
 *   - application data: logged (when verbose) and counted;
 *   - SSL_ERROR_WANT_READ: nothing to deliver yet, keep the connection;
 *   - SSL_ERROR_ZERO_RETURN with SSL_RECEIVED_SHUTDOWN set: answer with
 *     SSL_shutdown(), then close;
 *   - any other error: close without sending a shutdown alert.
 * Closing a connection closes its socket, removes it from readfd_set and
 * read_rfds, and frees its SSL object.
 */
void connection_read()
{
  int i, retval;
  int maxfd_i, maxfd;
  char buf[BUFFER_SIZE];
  fd_set temp_set;
  struct timeval tv;

  tv.tv_sec = 1;
  tv.tv_usec = 0;

  maxfd_i = getMaxfd(readfd_set);
  if (maxfd_i == -1)
    return;
  maxfd = readfd_set[maxfd_i].fd;

  /* select() overwrites its set, so work on a copy of the master set. */
  memcpy(&temp_set, &read_rfds, sizeof(fd_set));

  retval = select(maxfd + 1, &temp_set, NULL, NULL, &tv);
  /* Don't rely on the value of tv now! */
  if (retval == -1)
    {
      perror("D...select()");
      return;
    }
  if (retval == 0)
    {
      printf(" Time out.....\n");
      return;
    }
  printf("Data available:\n");

  for (i = 0; i < MAX_CONNECTIONS; i++)
    {
      int fd = readfd_set[i].fd;
      int *count = &readfd_set[i].count;
      SSL *ssl;
      int len, err;

      if (fd == -1)
        {
          printf("No data to read for %d\n", i);
          continue;
        }
      if (!FD_ISSET(fd, &temp_set))
        {
          printf("No fd set to read for %d\n", i);
          continue;
        }

      ssl = (SSL *) readfd_set[i].token;
      printf("SSL_read on fd = %d, ssl = %p\n", fd, (void *) ssl);

      len = SSL_read(ssl, buf, sizeof(buf));
      err = SSL_get_error(ssl, len);

      /* In this switch, `continue` keeps the connection open and moves on to
       * the next slot; `break` falls through to the close path below. */
      switch (err)
        {
        case SSL_ERROR_NONE:
          if (verbose)
            printf(" read %d bytes message: %d\n ", len, *count);
          (*count)++;
          continue;
        case SSL_ERROR_WANT_READ:
          printf("ERROR_WANT_READ.. try again\n");
          /* Just try again */
          continue;
        case SSL_ERROR_ZERO_RETURN:
          printf("ERROR_ZERO_RETURN.. try again\n");
          if (!(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN))
            {
              printf("SHUTDOWN NOT ....RECEIVED.: \n");
              continue;
            }
          printf(" SHUTDOWN RECEIVED.: \n");
          /* Answer the peer's close_notify before closing. */
          SSL_shutdown(ssl);
          break;
        default:
          printf("Unexpected error while reading: %d\n", err);
          printf("%s\n", ERR_error_string(ERR_get_error(), buf));
          break;
        }

#if WIN32
      closesocket(fd);
#else
      close(fd);
#endif
      clearFDSet(readfd_set, fd);
      FD_CLR(fd, &read_rfds);
      SSL_free(ssl);
      ERR_remove_state(0);
      if (verbose)
        printf("done, connection closed.\n");
    }
}


/*
 * Create and configure the DTLS server context: certificate, private key,
 * client-certificate request and cookie callbacks.
 * Exits the process if the context cannot be allocated.
 */
static SSL_CTX *create_server_ctx(void)
{
	SSL_CTX *ctx;
	char errbuf[256];   /* ERR_error_string() needs at least 256 bytes */

	OpenSSL_add_ssl_algorithms();
	SSL_load_error_strings();
	ctx = SSL_CTX_new(DTLSv1_server_method());
	if (ctx == NULL) {
		printf("%s\n", ERR_error_string(ERR_get_error(), errbuf));
		exit(-1);
	}
	/* No SSL_CTX_set_cipher_list() call: the library default list is used. */
	SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);

	if (!SSL_CTX_use_certificate_file(ctx, "server-cert.pem", SSL_FILETYPE_PEM))
		printf("\nERROR: no certificate found!");

	if (!SSL_CTX_use_PrivateKey_file(ctx, "server-key.pem", SSL_FILETYPE_PEM))
		printf("\nERROR: no private key found!");

	if (!SSL_CTX_check_private_key(ctx))
		printf("\nERROR: invalid private key!");

	/* Request a client certificate. It is not required, because
	 * SSL_VERIFY_FAIL_IF_NO_PEER_CERT is not set, and dtls_verify_callback
	 * accepts whatever is presented. */
	SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE, dtls_verify_callback);

	SSL_CTX_set_read_ahead(ctx, 1);
	SSL_CTX_set_cookie_generate_cb(ctx, generate_cookie);
	SSL_CTX_set_cookie_verify_cb(ctx, verify_cookie);

	return ctx;
}

/*
 * Create an SSL object that reads from the shared listening socket `fd`
 * (not closed by the BIO) with cookie exchange enabled, ready for
 * DTLSv1_listen(). Exits the process on allocation failure.
 */
static SSL *new_listen_ssl(SSL_CTX *ctx, int fd)
{
	BIO *bio = BIO_new_dgram(fd, BIO_NOCLOSE);
	SSL *ssl = SSL_new(ctx);

	if (bio == NULL || ssl == NULL) {
		printf("unable to allocate DTLS listener\n");
		exit(-1);
	}

	SSL_set_bio(ssl, bio, bio);
	SSL_set_options(ssl, SSL_OP_COOKIE_EXCHANGE);
	return ssl;
}

/*
 * Run the DTLS server forever on UDP `port`.
 *
 * `local_address` selects the IPv4 address to bind. NULL or "" binds to
 * all local addresses. Otherwise it must be a dotted-quad string accepted
 * by inet_addr(); 255.255.255.255 is rejected because inet_addr() cannot
 * tell it apart from an error. An invalid address prints a message and
 * returns.
 *
 * Each loop iteration:
 *   1. waits up to one second for a datagram on the listening socket;
 *   2. runs DTLSv1_listen() (cookie exchange);
 *   3. hands each verified client to connection_handle();
 *   4. services pending handshakes and data via connection_accept() and
 *      connection_read().
 * Exits the process on unrecoverable setup errors (context, socket, bind).
 */
void start_server(int port, char *local_address)
{
	int fd;
	struct sockaddr_in client_addr;
	struct sockaddr_in server_addr;
	SSL_CTX *ctx;
	SSL *ssl;
	const int on = 1;
	int i;

	for (i = 0; i < MAX_CONNECTIONS; i++) {
		memset(&acceptfd_set[i], 0, sizeof(struct fdinfo));
		acceptfd_set[i].fd = -1;
		memset(&readfd_set[i], 0, sizeof(struct fdinfo));
		readfd_set[i].fd = -1;
	}

	FD_ZERO(&accept_rfds);
	FD_ZERO(&read_rfds);

	/* Zeroing leaves sin_addr as INADDR_ANY unless an address is given. */
	memset(&server_addr, 0, sizeof(struct sockaddr_in));
	server_addr.sin_family = AF_INET;
	server_addr.sin_port = htons(port);
	if (local_address != NULL && local_address[0] != '\0') {
		server_addr.sin_addr.s_addr = inet_addr(local_address);
		if (server_addr.sin_addr.s_addr == INADDR_NONE) {
			printf("invalid local address: %s\n", local_address);
			return;
		}
	}

	ctx = create_server_ctx();

	fd = socket(server_addr.sin_family, SOCK_DGRAM, 0);
	if (fd < 0) {
		perror("socket");
		exit(-1);
	}

	set_non_blocking(fd);
	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (const void *) &on, sizeof(on));
	if (bind(fd, (const struct sockaddr *) &server_addr, sizeof(struct sockaddr_in)) < 0) {
		perror("bind");
		exit(-1);
	}

	memset(&client_addr, 0, sizeof(struct sockaddr_in));
	ssl = new_listen_ssl(ctx, fd);

	while (1) {
		fd_set rfds;
		struct timeval tv;
		int retval, ret;

		/* Watch the listening socket for incoming datagrams. */
		FD_ZERO(&rfds);
		FD_SET(fd, &rfds);

		tv.tv_sec = 1;
		tv.tv_usec = 0;

		retval = select(fd + 1, &rfds, NULL, NULL, &tv);
		/* Don't rely on the value of tv now! */

		if (retval == -1) {
			perror("select()");
		} else if (retval == 0) {
			printf("Listen  No data within 1 second.\n");
		} else {
			printf("Listen Data is available now.\n");

			ret = DTLSv1_listen(ssl, &client_addr);
			if (ret < 0) {
				printf("Listen thread 1In progress Listen ret code = %d\n", ret);
			} else if (ret == 0) {
				printf("Listen thread Error Listen ret code = %d\n", ret);
			} else if (ret == 1) {
				/* connection_handle() does not keep a pointer to this
				 * record, so a stack copy suffices (nothing to free). */
				struct pass_info info;

				memcpy(&info.server_addr, &server_addr, sizeof(struct sockaddr_in));
				memcpy(&info.client_addr, &client_addr, sizeof(struct sockaddr_in));
				info.ssl = ssl;
				memset(&client_addr, 0, sizeof(struct sockaddr_in));

				/* The accepted SSL now belongs to the client; listen
				 * with a fresh one. */
				ssl = new_listen_ssl(ctx, fd);

				if (connection_handle(&info) != 0) {
					perror("unable to handle connection");
					exit(-1);
				}
			} else {
				printf("unhandled Error Listen ret code = %d\n", ret);
			}
		}

		connection_accept();
		connection_read();
	}
}



/*
 * Entry point: initialise the socket layer, then run the DTLS server on
 * port 23232, bound to all local IPv4 addresses (empty local address).
 * Command-line arguments are currently ignored.
 *
 * Returns 1 if the socket layer cannot be initialised. start_server() only
 * returns on a setup error, after which the socket layer is torn down.
 */
int main(int argc, char **argv)
{
	int port = 23232;
	char local_addr[32];

	(void) argc;
	(void) argv;

	memset(local_addr, 0, sizeof(local_addr));

	if (NetworkStart() != 0) {
		printf("unable to initialise the network layer\n");
		return 1;
	}

	start_server(port, local_addr);

	NetworkTerm();

	return 0;
}
