/*
 * Copyright (C) The Apache Software Foundation. All rights reserved.
 *
 * This software is published under the terms of the Apache Software License
 * version 1.1, a copy of which has been included with this distribution in
 * the LICENSE.txt file.
 */
package org.apache.avalon.cornerstone.blocks.sockets;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.UnknownHostException;
import java.security.KeyStore;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.security.cert.X509Certificate;
import org.apache.avalon.cornerstone.services.sockets.SocketFactory;
import org.apache.avalon.framework.activity.Initializable;
import org.apache.avalon.framework.configuration.Configurable;
import org.apache.avalon.framework.configuration.Configuration;
import org.apache.avalon.framework.configuration.ConfigurationException;
import org.apache.avalon.framework.context.Context;
import org.apache.avalon.framework.context.Contextualizable;
import org.apache.avalon.framework.logger.AbstractLogEnabled;
import org.apache.avalon.phoenix.BlockContext;
import org.apache.avalon.framework.container.ContainerUtil;

/**
 * Manufactures TLS client sockets. Configuration element inside a
 * SocketManager would look like:
 * <pre>
 *  &lt;factory name="secure"
 *            class="org.apache.avalon.cornerstone.blocks.sockets.TLSSocketFactory" &gt;
 *   &lt;ssl-context/ &gt;
 *   &lt;timeout&gt; 0 &lt;/timeout&gt; &lt;!-- if the value is greater than
 *   zero, a read() call on the InputStream associated
 *   with this Socket will block for only this amount of time in
 *   milliseconds. Default value is 0. --&gt;
 *   &lt;verify-server-identity&gt;true|false&lt;/verify-server-identity&gt;
 *   &lt;!-- whether or not the server identity should be verified. Defaults to false. --&gt;
 * &lt;/factory&gt;
 * </pre>
 * <p>
 * The contents of <tt>ssl-context</tt> are described separately in {@link SSLFactoryBuilder}.
 * </p>
 * <p>
 * Server identity verification currently includes only comparing the
 * certificate Common Name received with the host name in the
 * passed address. Identity verification requires that the SSL
 * handshake be completed for the socket, so it takes longer
 * to get a verified socket (and won't play well with non-blocking
 * applications like SEDA).
 * </p>
 * <p>
 * Another thing to keep in mind when using identity verification is
 * that <tt>InetAddress</tt> objects for the remote hosts should be
 * built using {@link java.net.InetAddress#getByName} with
 * the host name (matching the certificate CN) as the
 * argument. Failure to do so may cause relatively costly DNS lookups
 * and false rejections caused by inconsistencies between forward and
 * reverse resolution.
 * </p>
 *
 * @author <a href="mailto:peter at apache.org">Peter Donald</a>
 * @author <a href="mailto:fede@apache.org">Federico Barbieri</a>
 * @author <a href="mailto:charles@benett1.demon.co.uk">Charles Benett</a>
 * @author <a href="mailto:">Harish Prabandham</a>
 * @author <a href="mailto:">Costin Manolache</a>
 * @author <a href="mailto:">Craig McClanahan</a>
 * @author <a href="mailto:myfam@surfeu.fi">Andrei Ivanov</a>
 *
 */
public class TLSSocketFactory
    extends AbstractLogEnabled
    implements SocketFactory, Contextualizable, Configurable, Initializable
{
    private SSLSocketFactory m_factory;

    private final static int WAIT_FOREVER = 0;
    private int m_socketTimeOut;
    private boolean m_verifyServerIdentity;

    private Context m_context;
    private Configuration m_childConfig;

    public void contextualize( final Context context )
    {
        m_context = context;
    }

    /**
     * Configures the factory.
     *
     * @param configuration the Configuration
     * @exception ConfigurationException if an error occurs
     */
    public void configure( final Configuration configuration )
        throws ConfigurationException
    {
        m_socketTimeOut = configuration.getChild( "timeout" ).getValueAsInteger( WAIT_FOREVER );
        m_verifyServerIdentity = configuration.getChild( "verify-server-identity" ).getValueAsBoolean( false );
        m_childConfig = configuration.getChild( "ssl-context" );
        if ( m_childConfig == null ) {
            getLogger().warn( "ssl-context child not found, please" +
                              " update your configuration according to" +
                              " documentation. Reverting to the" +
                              " old configuration format." );
            // not completely compatible though
            m_childConfig = configuration;
        }
    }

    /**
     * Creates an SSL factory using the confuration values.
     */
    public void initialize() throws Exception
    {
        SSLFactoryBuilder builder = new SSLFactoryBuilder();
        setupLogger( builder );
        ContainerUtil.contextualize( builder, m_context );
        ContainerUtil.configure( builder, m_childConfig );

        m_factory = builder.buildFactory();

        ContainerUtil.shutdown( builder );
    }

    /**
     * Does the unconditional part of socket initialization that
     * applies to all Sockets.
     */
    private Socket initSocket( final Socket socket )
        throws IOException
    {
        socket.setSoTimeout( m_socketTimeOut );
        return socket;
    }

    private Socket sslWrap(Socket bareSocket, InetAddress address, int port)
        throws IOException
    {
        String hostName = address.getHostName();
        SSLSocket sslSocket = (SSLSocket)
            m_factory.createSocket( bareSocket, hostName, port, true );
        sslSocket.startHandshake();
        SSLSession session = sslSocket.getSession();
        String DN = 
            session.getPeerCertificateChain()[0]
            .getSubjectDN().getName();
        String CN = getCN(DN);
        if (! hostName.equals(CN)) {
            throw new IOException("Host name mismatch, expecting " + 
                                  hostName + " DN is " + DN);
        }
        if (getLogger().isDebugEnabled()) {
            getLogger().debug("DN of the server " + DN);
            byte [] sessionId = session.getId();
            StringBuffer tmp = new StringBuffer("Session id ");
            for (int i = 0; i < sessionId.length; i++) {
                byte signedValue = sessionId[i];
                int unsignedByteValue = (signedValue >= 0) ? signedValue : 256 + signedValue;
                tmp.append(Integer.toHexString(unsignedByteValue)).append(':');
            }
            getLogger().debug(tmp.toString());
        }
        return sslSocket;
    }

    /**
     * Finds the Common Name specified from the given Distinguished
     * Name. Normally CN is the first part of the DN.
     *
     * @return the common name or null if DN is malformed
     */
    private String getCN(String DN) {
        int startOfCN = DN.indexOf("CN=");
        if (startOfCN < 0) {
            return null;
        }
        int startOfHostName = startOfCN + "CN=".length();
        int endOfHostName = DN.indexOf(',', startOfHostName);
        if (endOfHostName > 0) {
            return DN.substring(startOfHostName, endOfHostName);
        } else {
            return null;
        }
    }

    /**
     * Creates a socket connected to the specified remote address.
     *
     * @param address the remote address
     * @param port the remote port
     * @return the socket
     * @exception IOException if an error occurs
     */
    public Socket createSocket( InetAddress address, int port ) throws IOException
    {
        // Uses 2 different approaches to socket construction, because
        // sslWrap depends on wrapping createSocket which in turn
        // requires that address be resoved to the host name.
        //
        if (m_verifyServerIdentity) {
            return sslWrap( initSocket( new Socket( address, port )),
                            address, port );
        } else {
            return initSocket( m_factory.createSocket( address, port ));
        }
    }

    /**
     * Creates a socket and connected to the specified remote address
     * originating from specified local address.
     *
     * @param address the remote address
     * @param port the remote port
     * @param localAddress the local address
     * @param localPort the local port
     * @return the socket
     * @exception IOException if an error occurs
     */
    public Socket createSocket( InetAddress address, int port, InetAddress localAddress, int localPort ) throws IOException
    {
        return sslWrap( initSocket( new Socket( address, port,
                                                localAddress, localPort )),
                        address, port );
        
    }

}

