This is an automated email from the ASF dual-hosted git repository.

hutcheb pushed a commit to branch fix/PLC4X-303
in repository https://gitbox.apache.org/repos/asf/plc4x.git


The following commit(s) were added to refs/heads/fix/PLC4X-303 by this push:
     new eae36a9  Should now work, need to add tests next.
eae36a9 is described below

commit eae36a9181da211fd70047dbaf729dd2064d4c82
Author: hutcheb <[email protected]>
AuthorDate: Fri Nov 5 04:11:13 2021 +1000

    Should now work, need to add tests next.
---
 .../plc4x/java/opcua/context/SecureChannel.java    | 122 +++++++++++++++------
 1 file changed, 87 insertions(+), 35 deletions(-)

diff --git a/plc4j/drivers/opcua/src/main/java/org/apache/plc4x/java/opcua/context/SecureChannel.java b/plc4j/drivers/opcua/src/main/java/org/apache/plc4x/java/opcua/context/SecureChannel.java
index 1aa9cc8..9daf885 100644
--- a/plc4j/drivers/opcua/src/main/java/org/apache/plc4x/java/opcua/context/SecureChannel.java
+++ b/plc4j/drivers/opcua/src/main/java/org/apache/plc4x/java/opcua/context/SecureChannel.java
@@ -43,6 +43,7 @@ import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.security.cert.CertificateEncodingException;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.List;
@@ -52,6 +53,9 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Stream;
 
 public class SecureChannel {
 
@@ -82,6 +86,15 @@ public class SecureChannel {
         new ExtensionObjectEncodingMask(false, false, false),
         new NullExtension());               // Body
 
+    public static final Pattern INET_ADDRESS_PATTERN = 
Pattern.compile("(.(?<transportCode>tcp))?://" +
+        "(?<transportHost>[\\w.-]+)(:" +
+        "(?<transportPort>\\d*))?");
+
+    public static final Pattern URI_PATTERN = 
Pattern.compile("^(?<protocolCode>opc)" +
+        INET_ADDRESS_PATTERN +
+        "(?<transportEndpoint>[\\w/=]*)[\\?]?"
+    );
+
     private static final long EPOCH_OFFSET = 116444736000000000L;         
//Offset between OPC UA epoch time and linux epoch time.
     private static final PascalString APPLICATION_URI = new 
PascalString("urn:apache:plc4x:client");
     private static final PascalString PRODUCT_URI = new 
PascalString("urn:apache:plc4x:client");
@@ -91,6 +104,7 @@ public class SecureChannel {
     private final byte[] clientNonce = RandomUtils.nextBytes(40);
     private AtomicInteger requestHandleGenerator = new AtomicInteger(1);
     private PascalString policyId;
+    private UserTokenType tokenType;
     private PascalString endpoint;
     private boolean discovery;
     private String username;
@@ -118,7 +132,8 @@ public class SecureChannel {
     private CompletableFuture<Void> keepAlive;
     private int sendBufferSize;
     private int maxMessageSize;
-    private String[] endpoints = new String[3];
+    private List<String> endpoints = new ArrayList<>();
+
     private AtomicLong senderSequenceNumber = new AtomicLong();
 
     public SecureChannel(DriverContext driverContext, OpcuaConfiguration 
configuration) {
@@ -155,9 +170,9 @@ public class SecureChannel {
         // Generate a list of endpoints we can use.
         try {
             InetAddress address = 
InetAddress.getByName(this.configuration.getHost());
-            this.endpoints[0] = "opc.tcp://" + address.getHostAddress() + ":" 
+ configuration.getPort() +  configuration.getTransportEndpoint();
-            this.endpoints[1] = "opc.tcp://" + address.getHostName() + ":" + 
configuration.getPort() +  configuration.getTransportEndpoint();
-            this.endpoints[2] = "opc.tcp://" + address.getCanonicalHostName() 
+ ":" + configuration.getPort() +  configuration.getTransportEndpoint();
+            this.endpoints.add(address.getHostAddress());
+            this.endpoints.add(address.getHostName());
+            this.endpoints.add(address.getCanonicalHostName());
         } catch (UnknownHostException e) {
             e.printStackTrace();
         }
@@ -487,7 +502,6 @@ public class SecureChannel {
         senderCertificate = 
sessionResponse.getServerCertificate().getStringValue();
         
encryptionHandler.setServerCertificate(EncryptionHandler.getCertificateX509(senderCertificate));
         this.senderNonce = sessionResponse.getServerNonce().getStringValue();
-        UserTokenType tokenType = UserTokenType.userTokenTypeAnonymous;
         String[] endpoints = new String[3];
         try {
             InetAddress address = 
InetAddress.getByName(this.configuration.getHost());
@@ -504,7 +518,7 @@ public class SecureChannel {
             throw new PlcRuntimeException("Unable to find endpoint - " + 
endpoints[1]);
         }
 
-        ExtensionObject userIdentityToken = getIdentityToken(tokenType, 
policyId.getStringValue());
+        ExtensionObject userIdentityToken = getIdentityToken(this.tokenType, 
policyId.getStringValue());
 
         int requestHandle = getRequestHandle();
 
@@ -1155,40 +1169,77 @@ public class SecureChannel {
         return this.tokenId.get();
     }
 
-    private void selectEndpoint(CreateSessionResponse sessionResponse) {
+    /**
+     * Selects the endpoint to use based on the connection string provided.
+     * If Discovery is disabled it will use the host address return from the 
server
+     * @param sessionResponse - The CreateSessionResponse message returned by 
the server
+     * @throws PlcRuntimeException - If no endpoint with a compatible policy 
is found raise and error.
+     */
+    private void selectEndpoint(CreateSessionResponse sessionResponse) throws 
PlcRuntimeException {
         List<String> returnedEndpoints = new LinkedList<String>();
-        for (ExtensionObjectDefinition extensionObject : 
sessionResponse.getServerEndpoints()) {
-            EndpointDescription endpointDescription = (EndpointDescription) 
extensionObject;
-            
returnedEndpoints.add(endpointDescription.getEndpointUrl().getStringValue());
+
+        // Get a list of the endpoints which match ours.
+        Stream<EndpointDescription> filteredEndpoints = 
Arrays.stream(Arrays.copyOf(sessionResponse.getServerEndpoints(), 
sessionResponse.getServerEndpoints().length, EndpointDescription[].class))
+            .filter(this::isEndpoint);
+
+        //Determine if the requested security policy is included in the 
endpoint
+        filteredEndpoints.forEach(endpoint -> 
hasIdentity(Arrays.copyOf(endpoint.getUserIdentityTokens(), 
endpoint.getUserIdentityTokens().length, UserTokenPolicy[].class)));
+
+        if (this.policyId == null) {
+            throw new PlcRuntimeException("Unable to find endpoint - " + 
this.endpoints.get(0));
+        }
+        if (this.tokenType == null) {
+            throw new PlcRuntimeException("Unable to find Security Policy for 
endpoint - " + this.endpoints.get(0));
         }
+    }
 
-        List<EndpointDescription> filteredEndpoints = 
Arrays.stream(sessionResponse.getServerEndpoints())
-                                                            .filter(endpoint 
-> isEndpoint((EndpointDescription) 
endpoint).getEndpointUrl().getStringValue().equals(hostEndpoints));
-
-
-
-        for (String hostEndpoints : endpoints) {
-            for (ExtensionObjectDefinition extensionObject : 
sessionResponse.getServerEndpoints()) {
-                EndpointDescription endpointDescription = 
(EndpointDescription) extensionObject;
-                if 
(endpointDescription.getEndpointUrl().getStringValue().equals(hostEndpoints)) {
-                    for (ExtensionObjectDefinition userTokenCast : 
endpointDescription.getUserIdentityTokens()) {
-                        UserTokenPolicy identityToken = (UserTokenPolicy) 
userTokenCast;
-                        if ((identityToken.getTokenType() == 
UserTokenType.userTokenTypeAnonymous) && (this.username == null)) {
-                            LOGGER.info("Using Endpoint {} with security {}", 
endpointDescription.getEndpointUrl().getStringValue(), 
identityToken.getPolicyId().getStringValue());
-                            policyId = identityToken.getPolicyId();
-                            tokenType = identityToken.getTokenType();
-                        } else if ((identityToken.getTokenType() == 
UserTokenType.userTokenTypeUserName) && (this.username != null)) {
-                            LOGGER.info("Using Endpoint {} with security {}", 
endpointDescription.getEndpointUrl().getStringValue(), 
identityToken.getPolicyId().getStringValue());
-                            policyId = identityToken.getPolicyId();
-                            tokenType = identityToken.getTokenType();
-                        }
-                    }
-                }
-            }
+    /**
+     * Checks each component of the return endpoint description against the 
connection string.
+     * If all are correct then return true.
+     * @param endpoint - EndpointDescription returned from server
+     * @return true if this endpoint matches our configuration
+     * @throws PlcRuntimeException - If the returned endpoint string doesn't 
match the format expected
+     */
+    public boolean isEndpoint(EndpointDescription endpoint) throws 
PlcRuntimeException {
+        // Split up the connection string into it's individual segments.
+        Matcher matcher = 
URI_PATTERN.matcher(endpoint.getEndpointUrl().getStringValue());
+        if (!matcher.matches()) {
+            throw new PlcRuntimeException(
+                "Endpoint returned from the server doesn't match the format 
'{protocol-code}:({transport-code})?//{transport-host}(:{transport-port})(/{transport-endpoint})'");
+        }
+        LOGGER.trace("Using Endpoint {} {} {}", 
matcher.group("transportHost"), matcher.group("transportPort"), 
matcher.group("transportEndpoint"));
+        if (this.configuration.isDiscovery() && 
!this.endpoints.contains(matcher.group("transportHost"))) {
+            return false;
         }
 
-        if (this.policyId == null) {
-            throw new PlcRuntimeException("Unable to find endpoint - " + 
endpoints[1]);
+        if 
(!this.configuration.getPort().equals(matcher.group("transportPort"))) {
+            return false;
+        }
+
+        if 
(!this.configuration.getTransportEndpoint().equals(matcher.group("transportEndpoint")))
 {
+            return false;
+        }
+
+        if (!this.configuration.isDiscovery()) {
+            this.configuration.setHost(matcher.group("transportHost"));
+        }
+
+        return true;
+    }
+
+    /**
+     *
+     * @param policies
+     */
+    private void hasIdentity(UserTokenPolicy[] policies) {
+        for (UserTokenPolicy identityToken : policies) {
+            if ((identityToken.getTokenType() == 
UserTokenType.userTokenTypeAnonymous) && (this.username == null)) {
+                policyId = identityToken.getPolicyId();
+                tokenType = identityToken.getTokenType();
+            } else if ((identityToken.getTokenType() == 
UserTokenType.userTokenTypeUserName) && (this.username != null)) {
+                policyId = identityToken.getPolicyId();
+                tokenType = identityToken.getTokenType();
+            }
         }
     }
 
@@ -1251,4 +1302,5 @@ public class SecureChannel {
     public static long getCurrentDateTime() {
         return (System.currentTimeMillis() * 10000) + EPOCH_OFFSET;
     }
+
 }

Reply via email to