
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.OutputKeys;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerConfigurationException;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.TransformerFactoryConfigurationError;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;

import org.apache.xml.security.Init;
import org.apache.xml.security.c14n.Canonicalizer;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.signature.XMLSignature;
import org.apache.xml.security.signature.XMLSignatureException;
import org.apache.xml.security.transforms.TransformationException;
import org.apache.xml.security.transforms.Transforms;
import org.apache.xml.security.utils.Constants;
import org.w3c.dom.DOMImplementation;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSSerializer;
import org.xml.sax.SAXException;


/**
 * Small command-line driver that loads {@code assertion.xml} from the working
 * directory, inserts an enveloped XML signature before the first SAML 2.0
 * {@code Subject} element, signs it with a freshly generated RSA key and
 * prints the resulting document to standard output.
 *
 * <p>NOTE(review): the key size (1024-bit RSA) and the SHA-1 based signature
 * and digest algorithms are kept exactly as before so the produced output is
 * unchanged; they are presumably acceptable only because this is a test
 * driver -- confirm before reusing elsewhere.
 */
public class SignatureNamespaceTest {

    /** File read by {@link #main(String[])}, resolved against the working directory. */
    private static final String ASSERTION_FILE = "assertion.xml";

    /** Namespace used to locate the {@code Subject} element. */
    private static final String SAML2_ASSERTION_NS = "urn:oasis:names:tc:SAML:2.0:assertion";

    /**
     * Signs {@code assertion.xml} and prints the signed document.
     *
     * <p>Any failure is reported on standard error and ends the program with
     * exit status 1 instead of continuing with a partially built signature.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Init.init();

        KeyPair rsaKeyPair = generateKeyPair("RSA", 1024);
        Key signingKey = rsaKeyPair.getPrivate();
        String signingAlgo = XMLSignature.ALGO_ID_SIGNATURE_RSA_SHA1;

        Document document = loadDocument(ASSERTION_FILE);
        if (document == null) {
            // loadDocument/parseStream already printed the underlying cause.
            System.err.println("Unable to load " + ASSERTION_FILE + "; nothing to sign.");
            System.exit(1);
            return;
        }

        try {
            XMLSignature sig = new XMLSignature(document, "", signingAlgo,
                    Canonicalizer.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);

            Element assertion = document.getDocumentElement();
            NodeList nodelist = assertion.getElementsByTagNameNS(SAML2_ASSERTION_NS, "Subject");
            // item(0) is null when there is no Subject; insertBefore then
            // appends the signature as the last child of the root element.
            Element subject = (Element) nodelist.item(0);
            assertion.insertBefore(sig.getElement(), subject);

            Transforms transforms = new Transforms(document);
            transforms.addTransform(Transforms.TRANSFORM_ENVELOPED_SIGNATURE);
            transforms.addTransform(Transforms.TRANSFORM_C14N_EXCL_OMIT_COMMENTS);

            // NOTE(review): presumably the root element carries an ID attribute
            // with the value "assertionID" -- verify against assertion.xml.
            sig.addDocument("#assertionID", transforms, Constants.ALGO_ID_DIGEST_SHA1);

            sig.sign(signingKey);
        } catch (XMLSecurityException e) {
            // Also covers TransformationException and XMLSignatureException,
            // both of which extend XMLSecurityException. Stop here rather than
            // printing a document with an incomplete signature.
            e.printStackTrace();
            System.exit(1);
            return;
        }

        System.out.println(nodeToString(document));
        // Alternative, indented output: prettyPrintXML(document);
    }

    /**
     * Opens the file at {@code path}, parses it with {@link #parseStream(InputStream)}
     * and always closes the stream afterwards.
     *
     * @param path file to read
     * @return the parsed document, or {@code null} if the file does not exist
     *         or could not be parsed (the cause is printed to standard error)
     */
    private static Document loadDocument(String path) {
        InputStream input = null;
        try {
            input = new FileInputStream(new File(path));
            return parseStream(input);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return null;
        } finally {
            if (input != null) {
                try {
                    input.close();
                } catch (IOException e) {
                    // Closing a read-only stream failed; the parse result is unaffected.
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Serializes a DOM node to a string using the DOM Level 3 Load/Save API.
     *
     * @param node node to serialize; a {@link Document} or any node that has an owner document
     * @return the serialized markup as produced by {@link LSSerializer#writeToString(Node)}
     * @throws IllegalStateException if the DOM implementation does not offer the "LS" 3.0 feature
     */
    public static String nodeToString(Node node) {
        DOMImplementation domImpl;
        if (node.getNodeType() == Node.DOCUMENT_NODE) {
            domImpl = ((Document) node).getImplementation();
        } else {
            domImpl = node.getOwnerDocument().getImplementation();
        }
        // getFeature returns null when the feature is not supported.
        DOMImplementationLS domImplLS = (DOMImplementationLS) domImpl.getFeature("LS", "3.0");
        if (domImplLS == null) {
            throw new IllegalStateException("DOM implementation does not support the LS 3.0 feature");
        }
        LSSerializer serializer = domImplLS.createLSSerializer();
        return serializer.writeToString(node);
    }

    /**
     * Writes an indented rendering of {@code node} to standard output,
     * omitting the XML declaration.
     *
     * <p>If no transformer can be created the problem is printed to standard
     * error and nothing is written.
     *
     * @param node node to print
     */
    public static void prettyPrintXML(Node node) {
        Transformer tr;
        try {
            tr = TransformerFactory.newInstance().newTransformer();
        } catch (TransformerConfigurationException e) {
            e.printStackTrace();
            return;
        } catch (TransformerFactoryConfigurationError e) {
            e.printStackTrace();
            return;
        }
        tr.setOutputProperty(OutputKeys.METHOD, "xml");
        tr.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
        tr.setOutputProperty(OutputKeys.INDENT, "yes");
        tr.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "3");

        try {
            tr.transform(new DOMSource(node), new StreamResult(System.out));
        } catch (TransformerException e) {
            e.printStackTrace();
        }
    }

    /**
     * Parses an XML stream into a namespace-aware DOM document.
     *
     * <p>The stream is not closed by this method's own code; the caller
     * remains responsible for closing it.
     *
     * @param input stream containing the XML to parse
     * @return the parsed document, or {@code null} if parsing failed with a
     *         {@link SAXException} or {@link IOException} (the cause is
     *         printed to standard error)
     * @throws IllegalStateException if no {@link DocumentBuilder} can be created
     */
    public static Document parseStream(InputStream input) {
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        dbf.setNamespaceAware(true);

        DocumentBuilder builder;
        try {
            builder = dbf.newDocumentBuilder();
        } catch (ParserConfigurationException e) {
            // Report to the caller instead of terminating the whole JVM from a helper.
            throw new IllegalStateException("Cannot create a namespace-aware DocumentBuilder", e);
        }

        try {
            return builder.parse(input);
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Generates a new key pair.
     *
     * @param algo      key algorithm name passed to {@link KeyPairGenerator#getInstance(String)}
     * @param keyLength key size passed to {@link KeyPairGenerator#initialize(int)}
     * @return the generated key pair
     * @throws IllegalArgumentException if no provider supports {@code algo}
     */
    public static KeyPair generateKeyPair(String algo, int keyLength) {
        KeyPairGenerator keyGenerator;
        try {
            keyGenerator = KeyPairGenerator.getInstance(algo);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalArgumentException("Unsupported key algorithm: " + algo, e);
        }
        keyGenerator.initialize(keyLength);
        return keyGenerator.generateKeyPair();
    }

}
