package com.test.security;

import java.util.Iterator;
import java.util.Map;

import javax.xml.namespace.QName;
import javax.xml.soap.SOAPException;
import javax.xml.soap.SOAPHeader;
import javax.xml.soap.SOAPMessage;

import org.apache.cxf.binding.soap.SoapMessage;
import org.apache.cxf.binding.soap.saaj.SAAJInInterceptor;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.ws.security.wss4j.WSS4JInInterceptor;
import org.apache.ws.security.WSConstants;
import org.apache.ws.security.handler.WSHandlerConstants;

import org.springframework.util.Assert;

import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * WS-Security inbound interceptor that adds the {@code Encrypt} action only for messages that
 * actually carry encryption material.
 * <p>
 * Before delegating to {@link WSS4JInInterceptor#handleMessage(SoapMessage)}, the first
 * {@code wsse:Security} header of the message is inspected for a direct
 * {@code xenc:EncryptedKey} child. If one is found, {@code Encrypt} is appended to the configured
 * default actions; otherwise only the default actions are used. The resulting action list is
 * stored on the message itself (under {@link WSHandlerConstants#ACTION}), not on the interceptor.
 */
public class DynamicEncryptionInInterceptor extends WSS4JInInterceptor
{
    /** Used to build a SAAJ view of the message when none is attached to it yet. */
    private final SAAJInInterceptor saaj = new SAAJInInterceptor();

    /**
     * Space-separated WSS4J actions applied to every message; {@code null} until
     * {@link #setDefaultActions(String)} is called.
     */
    private String defaultActions;

    /**
     * @param properties WSS4J configuration passed unchanged to {@link WSS4JInInterceptor}
     */
    public DynamicEncryptionInInterceptor(Map<String, Object> properties)
    {
        super(properties);
    }

    /**
     * Sets the per-message WSS4J action list, depending on whether the message contains an
     * {@code EncryptedKey}, and then runs the regular WSS4J inbound processing.
     */
    @Override
    public void handleMessage(SoapMessage mc) throws Fault
    {
        mc.put(WSHandlerConstants.ACTION, resolveActions(isKeyDefined(mc)));
        super.handleMessage(mc);
    }

    /**
     * Builds the WSS4J action list for a single message.
     *
     * @param encrypted whether the message carries an {@code EncryptedKey}
     * @return {@link #defaultActions} when {@code encrypted} is false; otherwise the default
     *         actions followed by {@code Encrypt}, or just {@code Encrypt} when no default
     *         actions are configured
     */
    private String resolveActions(boolean encrypted)
    {
        if (!encrypted)
        {
            return defaultActions;
        }
        // Avoid producing "null Encrypt" (unknown action) when no default actions were set.
        if (defaultActions == null || defaultActions.trim().isEmpty())
        {
            return WSHandlerConstants.ENCRYPT;
        }
        return defaultActions + ' ' + WSHandlerConstants.ENCRYPT;
    }

    /**
     * Checks whether the message's WS-Security header contains encryption key material.
     * <p>
     * Only the first {@code wsse:Security} header is inspected, and only its direct children.
     *
     * @param mc the incoming message
     * @return true if an {@code xenc:EncryptedKey} element is a direct child of the first
     *         {@code wsse:Security} header; false if there is no SAAJ message, no SOAP Header,
     *         no Security header, or no such child
     * @throws IllegalStateException if the SOAP envelope or header cannot be read
     */
    boolean isKeyDefined(SoapMessage mc)
    {
        SOAPMessage soap = getSOAPMessage(mc);
        if (soap == null)
        {
            return false;
        }

        SOAPHeader header;
        try
        {
            // SOAPEnvelope.getHeader() returns null when the envelope has no Header element,
            // whereas SOAPMessage.getSOAPHeader() may throw in that case.
            header = soap.getSOAPPart().getEnvelope().getHeader();
        }
        catch (SOAPException ex)
        {
            throw new IllegalStateException("Failure when trying to access SOAP header", ex);
        }
        if (header == null)
        {
            // no SOAP header at all, hence no security header
            return false;
        }

        // Wildcard type: older SAAJ returns a raw Iterator, newer SAAJ Iterator<javax.xml.soap.Node>.
        Iterator<?> it = header.getChildElements(new QName(WSConstants.WSSE_NS, WSConstants.WSSE_LN));
        if (it == null || !it.hasNext())
        {
            // no security header
            return false;
        }
        Node securityHeader = (Node) it.next();

        // Scan the direct children of the security header for xenc:EncryptedKey.
        NodeList children = securityHeader.getChildNodes();
        for (int i = 0; i < children.getLength(); i++)
        {
            Node child = children.item(i);
            if (WSConstants.ENC_NS.equals(child.getNamespaceURI())
                && WSConstants.ENC_KEY_LN.equals(child.getLocalName()))
            {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the SAAJ view of the message, creating it via {@link SAAJInInterceptor} if the
     * message does not carry one yet.
     * <p>
     * Copy-paste of private method of WSS4JInInterceptor.
     *
     * @param msg the incoming message
     * @return the message's {@link SOAPMessage} content, or {@code null} if still absent after
     *         running the SAAJ interceptor
     */
    SOAPMessage getSOAPMessage(SoapMessage msg)
    {
        SOAPMessage doc = msg.getContent(SOAPMessage.class);
        if (doc == null)
        {
            saaj.handleMessage(msg);
            doc = msg.getContent(SOAPMessage.class);
        }
        return doc;
    }

    /**
     * @param defaultActions space-separated WSS4J actions applied to every message; when
     *        encryption is detected, {@code Encrypt} is appended to them
     */
    public void setDefaultActions(String defaultActions)
    {
        this.defaultActions = defaultActions;
    }
}
