This is an automated email from the ASF dual-hosted git repository.
terrymanu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 786ac3d6e95 Add session attribution context propagation (#38745)
786ac3d6e95 is described below
commit 786ac3d6e9587c1576bbcd5f83aa8610ded7c816
Author: Liang Zhang <[email protected]>
AuthorDate: Thu May 28 17:16:27 2026 +0800
Add session attribution context propagation (#38745)
* Add session attribution context propagation
Introduce a trusted HTTP session attribution source configuration and resolver,
bind the attribution to MCP sessions, and propagate it through the request scope
for handler access. Update runtime logging and safety policy wording, and add
regression coverage for the configuration, validator, servlet, and session
lifecycle behavior.
* Add session attribution context propagation
Introduce a trusted HTTP session attribution source configuration and resolver,
bind the attribution to MCP sessions, and propagate it through the request scope
for handler access. Update runtime logging and safety policy wording, and add
regression coverage for the configuration, validator, servlet, and session
lifecycle behavior.
* Add session attribution context propagation
Introduce a trusted HTTP session attribution source configuration and resolver,
bind the attribution to MCP sessions, and propagate it through the request scope
for handler access. Update runtime logging and safety policy wording, and add
regression coverage for the configuration, validator, servlet, and session
lifecycle behavior.
---
.../shardingsphere/mcp/api/MCPHandlerContext.java | 13 ++
.../mcp/api/session/MCPSessionAttribution.java} | 16 ++-
.../mcp/bootstrap/MCPRuntimeLauncher.java | 6 +-
.../config/HttpTransportConfiguration.java | 15 ++-
... => SessionAttributionSourceConfiguration.java} | 10 +-
.../config/YamlHttpTransportConfiguration.java | 2 +
...YamlSessionAttributionSourceConfiguration.java} | 16 +--
.../YamlHttpTransportConfigurationSwapper.java | 30 ++++-
.../HttpTransportConfigurationValidator.java | 48 ++++++++
.../server/http/SessionAttributionResolver.java | 134 +++++++++++++++++++++
.../server/http/StreamableHttpMCPServlet.java | 59 ++++++---
.../ServerTransportSecurityValidatorFactory.java | 6 +-
...dingSphereServerTransportSecurityValidator.java | 24 ++++
.../YamlHttpTransportConfigurationSwapperTest.java | 56 +++++++++
.../http/SessionAttributionResolverTest.java | 53 ++++++++
.../server/http/StreamableHttpMCPServletTest.java | 27 +++++
...erverTransportSecurityValidatorFactoryTest.java | 21 ++--
...SphereServerTransportSecurityValidatorTest.java | 100 +++++----------
.../mcp/core/completion/MCPCompletionService.java | 2 +-
.../mcp/core/context/MCPRequestScope.java | 21 ++++
.../mcp/core/session/MCPSessionManager.java | 28 +++++
.../mcp/core/tool/MCPToolController.java | 2 +-
.../mcp/core/context/MCPRequestScopeTest.java | 52 ++++++++
.../capability/ServerCapabilitiesHandlerTest.java | 2 +-
.../mcp/core/session/MCPSessionManagerTest.java | 30 +++++
.../support/security/MCPClientSafetyPolicy.java | 5 +-
.../MCPModelFirstContractPayloadBuilderTest.java | 2 +-
.../security/MCPClientSafetyPolicyTest.java | 1 +
28 files changed, 659 insertions(+), 122 deletions(-)
diff --git
a/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/MCPHandlerContext.java
b/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/MCPHandlerContext.java
index 77168f5e886..a78f7ebecb0 100644
---
a/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/MCPHandlerContext.java
+++
b/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/MCPHandlerContext.java
@@ -17,8 +17,21 @@
package org.apache.shardingsphere.mcp.api;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
+
+import java.util.Optional;
+
/**
* Marker interface for MCP handler execution context.
*/
public interface MCPHandlerContext {
+
+ /**
+ * Find the attribution bound to the MCP session of this handler context.
+ *
+ * <p>The default implementation always returns an empty result, so context
+ * implementations that do not track attribution need not override it.</p>
+ *
+ * @return session attribution, or empty when none is available
+ */
+ default Optional<MCPSessionAttribution> findSessionAttribution() {
+ return Optional.empty();
+ }
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
b/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/session/MCPSessionAttribution.java
similarity index 74%
copy from
mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
copy to
mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/session/MCPSessionAttribution.java
index 0fa500b6a59..8e04da25e2a 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
+++
b/mcp/api/src/main/java/org/apache/shardingsphere/mcp/api/session/MCPSessionAttribution.java
@@ -15,21 +15,25 @@
* limitations under the License.
*/
-package org.apache.shardingsphere.mcp.bootstrap.config;
+package org.apache.shardingsphere.mcp.api.session;
+import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
+import java.util.Map;
+
/**
- * HTTP transport configuration.
+ * MCP session attribution.
*/
@RequiredArgsConstructor
@Getter
-public final class HttpTransportConfiguration {
+@EqualsAndHashCode
+public final class MCPSessionAttribution {
- private final String bindHost;
+ private final String subject;
- private final int port;
+ private final String source;
- private final String endpointPath;
+ private final Map<String, String> attributes;
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/MCPRuntimeLauncher.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/MCPRuntimeLauncher.java
index e3de611f27c..71a6d03221e 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/MCPRuntimeLauncher.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/MCPRuntimeLauncher.java
@@ -25,6 +25,7 @@ import
org.apache.shardingsphere.mcp.bootstrap.config.MCPTransportType;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.MCPRuntimeServer;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.StreamableHttpMCPServer;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.stdio.StdioMCPServer;
+import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.SessionAttributionResolver;
import org.apache.shardingsphere.mcp.core.context.MCPRuntimeContext;
import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
import
org.apache.shardingsphere.mcp.support.database.capability.MCPDatabaseCapabilityProvider;
@@ -72,8 +73,9 @@ public final class MCPRuntimeLauncher {
private List<String> createHttpStartupLogMessages(final
MCPLaunchConfiguration config, final MCPRuntimeServer server) {
int port = server instanceof StreamableHttpMCPServer ?
((StreamableHttpMCPServer) server).getLocalPort() :
config.getHttpTransport().getPort();
String endpoint = String.format("http://%s:%d%s",
config.getHttpTransport().getBindHost(), port,
config.getHttpTransport().getEndpointPath());
- return List.of(String.format("ShardingSphere MCP runtime started,
transport=http, config=%s, databases=%d, endpoint=%s, authorization=none,
logs=%s.",
- configPath, config.getDatabases().size(), endpoint, LOG_PATH));
+ SessionAttributionResolver sessionAttributionResolver = new
SessionAttributionResolver(config.getHttpTransport().getSessionAttributionSource());
+ return List.of(String.format("ShardingSphere MCP runtime started,
transport=http, config=%s, databases=%d, endpoint=%s, session_attribution=%s,
logs=%s.",
+ configPath, config.getDatabases().size(), endpoint,
sessionAttributionResolver.getSummary(), LOG_PATH));
}
private List<String> createStdioStartupLogMessages(final
MCPLaunchConfiguration config) {
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
index 0fa500b6a59..f80f3af34c3 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
@@ -18,12 +18,10 @@
package org.apache.shardingsphere.mcp.bootstrap.config;
import lombok.Getter;
-import lombok.RequiredArgsConstructor;
/**
* HTTP transport configuration.
*/
-@RequiredArgsConstructor
@Getter
public final class HttpTransportConfiguration {
@@ -32,4 +30,17 @@ public final class HttpTransportConfiguration {
private final int port;
private final String endpointPath;
+
+ private final SessionAttributionSourceConfiguration
sessionAttributionSource;
+
+ /**
+ * Create HTTP transport configuration without a session attribution source.
+ *
+ * @param bindHost bind host
+ * @param port port
+ * @param endpointPath endpoint path
+ */
+ public HttpTransportConfiguration(final String bindHost, final int port,
final String endpointPath) {
+ this(bindHost, port, endpointPath, null);
+ }
+
+ /**
+ * Create HTTP transport configuration.
+ *
+ * @param bindHost bind host
+ * @param port port
+ * @param endpointPath endpoint path
+ * @param sessionAttributionSource session attribution source, or {@code null} to leave session attribution disabled
+ */
+ public HttpTransportConfiguration(final String bindHost, final int port,
final String endpointPath, final SessionAttributionSourceConfiguration
sessionAttributionSource) {
+ this.bindHost = bindHost;
+ this.port = port;
+ this.endpointPath = endpointPath;
+ this.sessionAttributionSource = sessionAttributionSource;
+ }
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/SessionAttributionSourceConfiguration.java
similarity index 80%
copy from
mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
copy to
mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/SessionAttributionSourceConfiguration.java
index 0fa500b6a59..45f5526aa33 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/SessionAttributionSourceConfiguration.java
@@ -21,15 +21,15 @@ import lombok.Getter;
import lombok.RequiredArgsConstructor;
/**
- * HTTP transport configuration.
+ * Session attribution source configuration.
*/
@RequiredArgsConstructor
@Getter
-public final class HttpTransportConfiguration {
+public final class SessionAttributionSourceConfiguration {
- private final String bindHost;
+ private final String subjectHeader;
- private final int port;
+ private final String sourceHeader;
- private final String endpointPath;
+ private final String attributeHeaderPrefix;
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlHttpTransportConfiguration.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlHttpTransportConfiguration.java
index 8c6c8723944..559b15394ab 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlHttpTransportConfiguration.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlHttpTransportConfiguration.java
@@ -35,4 +35,6 @@ public final class YamlHttpTransportConfiguration implements
YamlConfiguration {
private Integer port;
private String endpointPath;
+
+ private YamlSessionAttributionSourceConfiguration sessionAttributionSource;
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlSessionAttributionSourceConfiguration.java
similarity index 72%
copy from
mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
copy to
mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlSessionAttributionSourceConfiguration.java
index 0fa500b6a59..49fe9588a1a 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/HttpTransportConfiguration.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/config/YamlSessionAttributionSourceConfiguration.java
@@ -15,21 +15,21 @@
* limitations under the License.
*/
-package org.apache.shardingsphere.mcp.bootstrap.config;
+package org.apache.shardingsphere.mcp.bootstrap.config.yaml.config;
import lombok.Getter;
-import lombok.RequiredArgsConstructor;
+import lombok.Setter;
/**
- * HTTP transport configuration.
+ * YAML session attribution source configuration.
*/
-@RequiredArgsConstructor
@Getter
-public final class HttpTransportConfiguration {
+@Setter
+public final class YamlSessionAttributionSourceConfiguration {
- private final String bindHost;
+ private String subjectHeader;
- private final int port;
+ private String sourceHeader;
- private final String endpointPath;
+ private String attributeHeaderPrefix;
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapper.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapper.java
index 3b3635f1be8..0cd6653b5ad 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapper.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapper.java
@@ -17,9 +17,11 @@
package org.apache.shardingsphere.mcp.bootstrap.config.yaml.swapper;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
import
org.apache.shardingsphere.infra.util.yaml.swapper.YamlConfigurationSwapper;
import
org.apache.shardingsphere.mcp.bootstrap.config.HttpTransportConfiguration;
import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlHttpTransportConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlSessionAttributionSourceConfiguration;
import
org.apache.shardingsphere.mcp.support.yaml.MCPYamlConfigurationValidator;
import java.util.Objects;
@@ -34,12 +36,30 @@ public final class YamlHttpTransportConfigurationSwapper
implements YamlConfigur
private static final String DEFAULT_ENDPOINT_PATH = "/mcp";
+ private static final String DEFAULT_SUBJECT_HEADER =
"X-ShardingSphere-MCP-Subject";
+
+ private static final String DEFAULT_SOURCE_HEADER =
"X-ShardingSphere-MCP-Source";
+
+ private static final String DEFAULT_ATTRIBUTE_HEADER_PREFIX =
"X-ShardingSphere-MCP-Attribute-";
+
@Override
public YamlHttpTransportConfiguration swapToYamlConfiguration(final
HttpTransportConfiguration data) {
YamlHttpTransportConfiguration result = new
YamlHttpTransportConfiguration();
result.setBindHost(data.getBindHost());
result.setPort(data.getPort());
result.setEndpointPath(data.getEndpointPath());
+
result.setSessionAttributionSource(swapToYamlConfiguration(data.getSessionAttributionSource()));
+ return result;
+ }
+
+ // A null source configuration means session attribution is disabled, so no YAML section is produced for it.
+ private YamlSessionAttributionSourceConfiguration
swapToYamlConfiguration(final SessionAttributionSourceConfiguration data) {
+ if (null == data) {
+ return null;
+ }
+ YamlSessionAttributionSourceConfiguration result = new
YamlSessionAttributionSourceConfiguration();
+ result.setSubjectHeader(data.getSubjectHeader());
+ result.setSourceHeader(data.getSourceHeader());
+ result.setAttributeHeaderPrefix(data.getAttributeHeaderPrefix());
return result;
}
@@ -50,7 +70,15 @@ public final class YamlHttpTransportConfigurationSwapper
implements YamlConfigur
}
MCPYamlConfigurationValidator.validate(yamlConfig, "MCP HTTP transport
configuration");
return new
HttpTransportConfiguration(getValueOrDefault(yamlConfig.getBindHost(),
DEFAULT_BIND_HOST), null == yamlConfig.getPort() ? DEFAULT_PORT :
yamlConfig.getPort(),
- getValueOrDefault(yamlConfig.getEndpointPath(),
DEFAULT_ENDPOINT_PATH));
+ getValueOrDefault(yamlConfig.getEndpointPath(),
DEFAULT_ENDPOINT_PATH), swapToObject(yamlConfig.getSessionAttributionSource()));
+ }
+
+ // An absent YAML section keeps session attribution disabled (null). When the section is present, each header
+ // name is passed through getValueOrDefault with the matching X-ShardingSphere-MCP-* default constant.
+ private SessionAttributionSourceConfiguration swapToObject(final
YamlSessionAttributionSourceConfiguration yamlConfig) {
+ if (null == yamlConfig) {
+ return null;
+ }
+ return new
SessionAttributionSourceConfiguration(getValueOrDefault(yamlConfig.getSubjectHeader(),
DEFAULT_SUBJECT_HEADER),
+ getValueOrDefault(yamlConfig.getSourceHeader(),
DEFAULT_SOURCE_HEADER),
getValueOrDefault(yamlConfig.getAttributeHeaderPrefix(),
DEFAULT_ATTRIBUTE_HEADER_PREFIX));
}
private String getValueOrDefault(final String value, final String
defaultValue) {
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/validator/HttpTransportConfigurationValidator.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/validator/HttpTransportConfigurationValidator.java
index 52a382d6c18..8292ae49a8d 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/validator/HttpTransportConfigurationValidator.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/validator/HttpTransportConfigurationValidator.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.mcp.bootstrap.config.yaml.validator;
import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlHttpTransportConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlSessionAttributionSourceConfiguration;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
@@ -46,6 +47,9 @@ public final class HttpTransportConfigurationValidator
implements ConstraintVali
addViolation(context, "endpointPath", "must be a single absolute
path without query or fragment");
return false;
}
+ if
(!isValidSessionAttributionSource(value.getSessionAttributionSource(),
context)) {
+ return false;
+ }
return true;
}
@@ -80,6 +84,50 @@ public final class HttpTransportConfigurationValidator
implements ConstraintVali
}
}
+ // Validate the optional session attribution section. An absent section is valid, and so is an unset
+ // (null) header name, because the swapper substitutes a default for it. Only the first invalid property
+ // is reported, mirroring the early-return style of isValid.
+ private boolean isValidSessionAttributionSource(final
YamlSessionAttributionSourceConfiguration value, final
ConstraintValidatorContext context) {
+ if (null == value) {
+ return true;
+ }
+ if (!isValidHeaderName(value.getSubjectHeader())) {
+ addViolation(context, "sessionAttributionSource.subjectHeader",
"must be a valid HTTP header name");
+ return false;
+ }
+ if (!isValidHeaderName(value.getSourceHeader())) {
+ addViolation(context, "sessionAttributionSource.sourceHeader",
"must be a valid HTTP header name");
+ return false;
+ }
+ // The prefix is checked with the same token rule as a full header name.
+ if (!isValidHeaderName(value.getAttributeHeaderPrefix())) {
+ addViolation(context,
"sessionAttributionSource.attributeHeaderPrefix", "must be a valid HTTP header
name prefix");
+ return false;
+ }
+ return true;
+ }
+
+    /**
+     * Check whether the value is a legal HTTP field name (RFC 9110 token).
+     *
+     * <p>A {@code null} value is accepted because the swapper substitutes a default header name.
+     * The value is deliberately not trimmed: the swapper and resolver use the configured name verbatim,
+     * so surrounding whitespace would make the name never match a real header. Such values must be rejected
+     * here instead of silently disabling attribution.</p>
+     *
+     * @param value configured header name or prefix, may be null
+     * @return true if the value is null or a non-empty HTTP token
+     */
+    private boolean isValidHeaderName(final String value) {
+        if (null == value) {
+            return true;
+        }
+        if (value.isEmpty()) {
+            return false;
+        }
+        for (char each : value.toCharArray()) {
+            if (!isHttpTokenCharacter(each)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    // RFC 9110 tchar: ASCII digit, ASCII letter, or one of the permitted symbols.
+    private boolean isHttpTokenCharacter(final char value) {
+        boolean asciiDigit = '0' <= value && '9' >= value;
+        boolean asciiUpper = 'A' <= value && 'Z' >= value;
+        boolean asciiLower = 'a' <= value && 'z' >= value;
+        return asciiDigit || asciiUpper || asciiLower || isHttpTokenSymbol(value);
+    }
+
+    // Non-alphanumeric characters allowed in an RFC 9110 token: ! # $ % & ' * + - . ^ _ ` | ~
+    private boolean isHttpTokenSymbol(final char value) {
+        return -1 != "!#$%&'*+-.^_`|~".indexOf(value);
+    }
+
private void addViolation(final ConstraintValidatorContext context, final
String propertyName, final String message) {
context.disableDefaultConstraintViolation();
context.buildConstraintViolationWithTemplate(message).addPropertyNode(propertyName).addConstraintViolation();
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolver.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolver.java
new file mode 100644
index 00000000000..717338a7619
--- /dev/null
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolver.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.mcp.bootstrap.transport.server.http;
+
+import jakarta.servlet.http.HttpServletRequest;
+import lombok.Getter;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Session attribution resolver.
+ */
+@Getter
+public final class SessionAttributionResolver {
+
+ private final SessionAttributionSourceConfiguration config;
+
+ public SessionAttributionResolver(final
SessionAttributionSourceConfiguration config) {
+ this.config = config;
+ }
+
+ /**
+ * Resolve session attribution from HTTP request.
+ *
+ * @param request HTTP request
+ * @return session attribution
+ */
+ public Optional<MCPSessionAttribution> resolve(final HttpServletRequest
request) {
+ if (!isEnabled()) {
+ return Optional.empty();
+ }
+ String subject = getHeaderValue(request, config.getSubjectHeader());
+ if (subject.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.of(new MCPSessionAttribution(subject,
getHeaderValue(request, config.getSourceHeader()),
+ resolveAttributes(Collections.list(request.getHeaderNames()),
name -> getHeaderValue(request, name))));
+ }
+
+ /**
+ * Resolve session attribution from header map.
+ *
+ * @param headers headers
+ * @return session attribution
+ */
+ public Optional<MCPSessionAttribution> resolve(final Map<String,
List<String>> headers) {
+ if (!isEnabled()) {
+ return Optional.empty();
+ }
+ String subject = getHeaderValue(headers, config.getSubjectHeader());
+ if (subject.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.of(new MCPSessionAttribution(subject,
getHeaderValue(headers, config.getSourceHeader()),
+ resolveAttributes(headers.keySet(), name ->
getHeaderValue(headers, name))));
+ }
+
+ /**
+ * Determine whether session attribution is enabled.
+ *
+ * @return true if enabled
+ */
+ public boolean isEnabled() {
+ return null != config;
+ }
+
+ /**
+ * Get summary for diagnostics.
+ *
+ * @return summary
+ */
+ public String getSummary() {
+ return !isEnabled() ? "disabled" : String.format("trusted-header:%s",
config.getSubjectHeader());
+ }
+
+ private Map<String, String> resolveAttributes(final Iterable<String>
headerNames, final HeaderValueReader headerValueReader) {
+ Map<String, String> result = new LinkedHashMap<>();
+ String attributeHeaderPrefix = config.getAttributeHeaderPrefix();
+ if (attributeHeaderPrefix.isEmpty()) {
+ return result;
+ }
+ String normalizedPrefix =
attributeHeaderPrefix.toLowerCase(Locale.ENGLISH);
+ for (String each : headerNames) {
+ String actualHeaderName = Objects.toString(each, "").trim();
+ if
(actualHeaderName.toLowerCase(Locale.ENGLISH).startsWith(normalizedPrefix)) {
+
result.put(actualHeaderName.substring(attributeHeaderPrefix.length()).toLowerCase(Locale.ENGLISH),
headerValueReader.read(actualHeaderName));
+ }
+ }
+ return result;
+ }
+
+ private String getHeaderValue(final HttpServletRequest request, final
String headerName) {
+ return Objects.toString(request.getHeader(headerName), "").trim();
+ }
+
+ private String getHeaderValue(final Map<String, List<String>> headers,
final String headerName) {
+ for (Entry<String, List<String>> entry : headers.entrySet()) {
+ if (headerName.equalsIgnoreCase(entry.getKey()) &&
!entry.getValue().isEmpty()) {
+ return Objects.toString(entry.getValue().get(0), "").trim();
+ }
+ }
+ return "";
+ }
+
+ @FunctionalInterface
+ private interface HeaderValueReader {
+
+ String read(String headerName);
+ }
+}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServlet.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServlet.java
index 4524126f563..8580afc48c5 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServlet.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServlet.java
@@ -59,12 +59,15 @@ final class StreamableHttpMCPServlet extends HttpServlet
implements McpStreamabl
private final MCPSessionExecutionCoordinator sessionExecutionCoordinator;
+ private final SessionAttributionResolver sessionAttributionResolver;
+
private final Map<String, String> sessionProtocolVersions;
private final AtomicBoolean closed;
StreamableHttpMCPServlet(final MCPSessionManager sessionManager, final
McpJsonMapper jsonMapper, final HttpTransportConfiguration config) {
- delegate = createDelegate(sessionManager, jsonMapper,
config.getBindHost(), config.getEndpointPath());
+ sessionAttributionResolver = new
SessionAttributionResolver(config.getSessionAttributionSource());
+ delegate = createDelegate(sessionManager, jsonMapper,
config.getBindHost(), config.getEndpointPath(), sessionAttributionResolver);
this.sessionManager = sessionManager;
sessionExecutionCoordinator = new
MCPSessionExecutionCoordinator(sessionManager);
sessionProtocolVersions = new ConcurrentHashMap<>();
@@ -73,9 +76,9 @@ final class StreamableHttpMCPServlet extends HttpServlet
implements McpStreamabl
}
private static HttpServletStreamableServerTransportProvider
createDelegate(final MCPSessionManager sessionManager, final McpJsonMapper
jsonMapper,
-
final String bindHost, final String endpointPath) {
+
final String bindHost, final String endpointPath, final
SessionAttributionResolver sessionAttributionResolver) {
return
HttpServletStreamableServerTransportProvider.builder().jsonMapper(jsonMapper).mcpEndpoint(endpointPath)
-
.securityValidator(ServerTransportSecurityValidatorFactory.create(sessionManager,
bindHost)).build();
+
.securityValidator(ServerTransportSecurityValidatorFactory.create(sessionManager,
bindHost, sessionAttributionResolver)).build();
}
@Override
@@ -138,7 +141,9 @@ final class StreamableHttpMCPServlet extends HttpServlet
implements McpStreamabl
response.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE,
"Content-Type must be application/json.");
return;
}
- serviceRequest(request, withInitializeProtocolHeader(response));
+ SessionAwareHttpServletResponse actualResponse =
withInitializeProtocolHeader(response);
+ serviceRequest(request, actualResponse);
+ bindSessionAttribution(request, actualResponse);
}
private boolean isJsonContentType(final HttpServletRequest request) {
@@ -179,30 +184,56 @@ final class StreamableHttpMCPServlet extends HttpServlet
implements McpStreamabl
}
}
- private HttpServletResponse withInitializeProtocolHeader(final
HttpServletResponse response) {
- return new HttpServletResponseWrapper(response) {
+    // Bind the resolved attribution to the session id recorded by the response wrapper. The wrapper records
+    // an id only when the session header was written; with no recorded id there is nothing to bind.
+    private void bindSessionAttribution(final HttpServletRequest request, final SessionAwareHttpServletResponse response) {
+        String announcedSessionId = response.getSessionId();
+        if (!announcedSessionId.isEmpty()) {
+            sessionAttributionResolver.resolve(request).ifPresent(attribution -> sessionManager.bindSessionAttribution(announcedSessionId, attribution));
+        }
+    }
+
+ private SessionAwareHttpServletResponse withInitializeProtocolHeader(final
HttpServletResponse response) {
+ return new SessionAwareHttpServletResponse(response) {
@Override
public void setHeader(final String name, final String value) {
super.setHeader(name, value);
- addNegotiatedProtocolHeader(name, value);
+ addNegotiatedProtocolHeader(this, name, value);
}
@Override
public void addHeader(final String name, final String value) {
super.addHeader(name, value);
- addNegotiatedProtocolHeader(name, value);
- }
-
- private void addNegotiatedProtocolHeader(final String name, final
String sessionId) {
- if (SESSION_HEADER.equalsIgnoreCase(name)) {
- super.setHeader(PROTOCOL_HEADER,
findNegotiatedProtocolVersion(sessionId));
- }
+ addNegotiatedProtocolHeader(this, name, value);
}
};
}
+ // When the session header is written, remember its value on the wrapper (read later by
+ // bindSessionAttribution) and set the protocol version header negotiated for that session.
+ // NOTE(review): response.setHeader re-enters the wrapper's overridden setHeader. This only terminates
+ // because PROTOCOL_HEADER is presumably not equal (ignoring case) to SESSION_HEADER; confirm.
+ private void addNegotiatedProtocolHeader(final
SessionAwareHttpServletResponse response, final String name, final String
sessionId) {
+ if (SESSION_HEADER.equalsIgnoreCase(name)) {
+ response.setSessionId(sessionId);
+ response.setHeader(PROTOCOL_HEADER,
findNegotiatedProtocolVersion(sessionId));
+ }
+ }
+
private String findNegotiatedProtocolVersion(final String sessionId) {
return
sessionProtocolVersions.getOrDefault(Objects.toString(sessionId, ""),
MCPTransportConstants.PROTOCOL_VERSION);
}
+
+ // Response wrapper that remembers the session id written through the session header, so the servlet can
+ // bind session attribution after the delegate has serviced the request. It declares no abstract methods
+ // and is instantiated here only through the anonymous subclass in withInitializeProtocolHeader.
+ private abstract static class SessionAwareHttpServletResponse extends
HttpServletResponseWrapper {
+
+ // Empty until a session header is written.
+ private String sessionId = "";
+
+ SessionAwareHttpServletResponse(final HttpServletResponse response) {
+ super(response);
+ }
+
+ protected final String getSessionId() {
+ return sessionId;
+ }
+
+ // Stores the trimmed id; a null id is stored as an empty string.
+ protected final void setSessionId(final String sessionId) {
+ this.sessionId = Objects.toString(sessionId, "").trim();
+ }
+ }
}
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactory.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactory.java
index 269837f5eb5..8677ad80357 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactory.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactory.java
@@ -21,6 +21,7 @@ import
io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import
org.apache.shardingsphere.mcp.bootstrap.transport.HttpTransportHostUtils;
+import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.SessionAttributionResolver;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.OriginHeaderConstraint;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.ProtocolVersionHeaderConstraint;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.TransportHeaderConstraint;
@@ -40,10 +41,11 @@ public final class ServerTransportSecurityValidatorFactory {
*
* @param sessionManager session manager
* @param bindHost bind host
+ * @param sessionAttributionResolver session attribution resolver
* @return transport security validator
*/
- public static ServerTransportSecurityValidator create(final
MCPSessionManager sessionManager, final String bindHost) {
- return new
ShardingSphereServerTransportSecurityValidator(sessionManager,
createConstraints(bindHost));
+ public static ServerTransportSecurityValidator create(final
MCPSessionManager sessionManager, final String bindHost, final
SessionAttributionResolver sessionAttributionResolver) {
+ return new
ShardingSphereServerTransportSecurityValidator(sessionManager,
createConstraints(bindHost), sessionAttributionResolver);
}
private static List<TransportHeaderConstraint> createConstraints(final
String bindHost) {
diff --git
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidator.java
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidator.java
index 1f92b0db00e..d8bb52c2eef 100644
---
a/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidator.java
+++
b/mcp/bootstrap/src/main/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidator.java
@@ -21,6 +21,8 @@ import
io.modelcontextprotocol.server.transport.ServerTransportSecurityException
import
io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator;
import io.modelcontextprotocol.spec.HttpHeaders;
import lombok.RequiredArgsConstructor;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
+import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.SessionAttributionResolver;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.SessionRequiredTransportHeaderConstraint;
import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.TransportHeaderConstraint;
import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
@@ -28,6 +30,7 @@ import
org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
/**
* ShardingSphere server transport security validator.
@@ -39,8 +42,11 @@ public final class
ShardingSphereServerTransportSecurityValidator implements Ser
private final List<TransportHeaderConstraint> constraints;
+ private final SessionAttributionResolver sessionAttributionResolver;
+
@Override
public void validateHeaders(final Map<String, List<String>> headers)
throws ServerTransportSecurityException {
+ validateSessionAttribution(headers);
for (TransportHeaderConstraint each : constraints) {
if (each instanceof SessionRequiredTransportHeaderConstraint) {
String sessionId = getFirstHeaderValue(headers,
HttpHeaders.MCP_SESSION_ID);
@@ -52,6 +58,24 @@ public final class
ShardingSphereServerTransportSecurityValidator implements Ser
}
}
+ /**
+ * Validate that the trusted attribution carried by a request matches the attribution bound to its existing MCP session.
+ *
+ * <p>Nothing is checked when the request has no session ID, the session is unknown, the resolver is disabled,
+ * or no attribution can be resolved from the request headers.</p>
+ *
+ * @param headers request headers
+ * @throws ServerTransportSecurityException with status 400 when the session has no bound attribution or the bound attribution differs from the resolved one
+ */
+ private void validateSessionAttribution(final Map<String, List<String>> headers) throws ServerTransportSecurityException {
+ String sessionId = getFirstHeaderValue(headers, HttpHeaders.MCP_SESSION_ID);
+ // Evaluation order matters: the session lookup runs before the enabled check, so it happens even when the resolver is disabled.
+ boolean attributionCheckApplies = !sessionId.isEmpty() && sessionManager.hasSession(sessionId) && sessionAttributionResolver.isEnabled();
+ if (!attributionCheckApplies) {
+ return;
+ }
+ Optional<MCPSessionAttribution> requestAttribution = sessionAttributionResolver.resolve(headers);
+ // NOTE(review): a request that carries no attribution headers skips the binding check entirely; confirm this is intended.
+ if (requestAttribution.isPresent()) {
+ MCPSessionAttribution boundAttribution = sessionManager.findSessionAttribution(sessionId)
+ .orElseThrow(() -> new ServerTransportSecurityException(400, String.format("Session attribution is not bound for session `%s`.", sessionId)));
+ if (!boundAttribution.equals(requestAttribution.get())) {
+ throw new ServerTransportSecurityException(400, String.format("Session attribution does not match existing binding for session `%s`.", sessionId));
+ }
+ }
+ }
+
private String getFirstHeaderValue(final Map<String, List<String>>
headers, final String headerName) {
return headers.entrySet().stream()
.filter(entry -> headerName.equalsIgnoreCase(entry.getKey())
&& !entry.getValue().isEmpty()).findFirst().map(optional ->
Objects.toString(optional.getValue().get(0), "").trim())
diff --git
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapperTest.java
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapperTest.java
index ef526729c7a..216d74ec08c 100644
---
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapperTest.java
+++
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/config/yaml/swapper/YamlHttpTransportConfigurationSwapperTest.java
@@ -18,11 +18,14 @@
package org.apache.shardingsphere.mcp.bootstrap.config.yaml.swapper;
import
org.apache.shardingsphere.mcp.bootstrap.config.HttpTransportConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlHttpTransportConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.config.yaml.config.YamlSessionAttributionSourceConfiguration;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
+import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
class YamlHttpTransportConfigurationSwapperTest {
@@ -35,6 +38,7 @@ class YamlHttpTransportConfigurationSwapperTest {
assertThat(actual.getBindHost(), is("127.0.0.1"));
assertThat(actual.getPort(), is(18088));
assertThat(actual.getEndpointPath(), is("/mcp"));
+ assertNull(actual.getSessionAttributionSource());
}
@Test
@@ -121,6 +125,50 @@ class YamlHttpTransportConfigurationSwapperTest {
assertThat(actual.getBindHost(), is("127.0.0.1"));
assertThat(actual.getPort(), is(18088));
assertThat(actual.getEndpointPath(), is("/mcp"));
+ assertNull(actual.getSessionAttributionSource());
+ }
+
+ @Test
+ void assertSwapToObjectWithSessionAttributionSource() {
+ YamlHttpTransportConfiguration yamlTransportConfig = createYamlConfig("127.0.0.1", 18088, "/mcp");
+ YamlSessionAttributionSourceConfiguration yamlAttributionSource = createSessionAttributionSource("X-Test-Subject", "X-Test-Source", "X-Test-Attr-");
+ yamlTransportConfig.setSessionAttributionSource(yamlAttributionSource);
+ SessionAttributionSourceConfiguration actual = swapper.swapToObject(yamlTransportConfig).getSessionAttributionSource();
+ assertThat(actual.getSubjectHeader(), is("X-Test-Subject"));
+ assertThat(actual.getSourceHeader(), is("X-Test-Source"));
+ assertThat(actual.getAttributeHeaderPrefix(), is("X-Test-Attr-"));
+ }
+
+ @Test
+ void assertSwapToObjectWithInvalidSessionAttributionSubjectHeader() {
+ YamlHttpTransportConfiguration yamlTransportConfig = createYamlConfig("127.0.0.1", 18088, "/mcp");
+ // A comma is not a legal HTTP header name character.
+ yamlTransportConfig.setSessionAttributionSource(createSessionAttributionSource("X-Test,Subject", "X-Test-Source", "X-Test-Attr-"));
+ String actualMessage = assertThrows(IllegalArgumentException.class, () -> swapper.swapToObject(yamlTransportConfig)).getMessage();
+ assertThat(actualMessage, is("MCP HTTP transport configuration property `sessionAttributionSource.subjectHeader` must be a valid HTTP header name."));
+ }
+
+ @Test
+ void assertSwapToObjectWithInvalidSessionAttributionSourceHeader() {
+ YamlHttpTransportConfiguration yamlTransportConfig = createYamlConfig("127.0.0.1", 18088, "/mcp");
+ // A slash is not a legal HTTP header name character.
+ yamlTransportConfig.setSessionAttributionSource(createSessionAttributionSource("X-Test-Subject", "X/Test-Source", "X-Test-Attr-"));
+ String actualMessage = assertThrows(IllegalArgumentException.class, () -> swapper.swapToObject(yamlTransportConfig)).getMessage();
+ assertThat(actualMessage, is("MCP HTTP transport configuration property `sessionAttributionSource.sourceHeader` must be a valid HTTP header name."));
+ }
+
+ @Test
+ void assertSwapToObjectWithInvalidSessionAttributionAttributeHeaderPrefix() {
+ YamlHttpTransportConfiguration yamlTransportConfig = createYamlConfig("127.0.0.1", 18088, "/mcp");
+ // An opening parenthesis is not a legal HTTP header name character.
+ yamlTransportConfig.setSessionAttributionSource(createSessionAttributionSource("X-Test-Subject", "X-Test-Source", "X-Test-Attr("));
+ String actualMessage = assertThrows(IllegalArgumentException.class, () -> swapper.swapToObject(yamlTransportConfig)).getMessage();
+ assertThat(actualMessage, is("MCP HTTP transport configuration property `sessionAttributionSource.attributeHeaderPrefix` must be a valid HTTP header name prefix."));
+ }
+
+ @Test
+ void assertSwapToYamlConfigurationWithSessionAttributionSource() {
+ SessionAttributionSourceConfiguration attributionSource = new SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source", "X-Test-Attr-");
+ HttpTransportConfiguration transportConfig = new HttpTransportConfiguration("127.0.0.1", 18088, "/mcp", attributionSource);
+ YamlSessionAttributionSourceConfiguration actual = swapper.swapToYamlConfiguration(transportConfig).getSessionAttributionSource();
+ assertThat(actual.getSubjectHeader(), is("X-Test-Subject"));
+ assertThat(actual.getSourceHeader(), is("X-Test-Source"));
+ assertThat(actual.getAttributeHeaderPrefix(), is("X-Test-Attr-"));
}
private YamlHttpTransportConfiguration createYamlConfig(final String
bindHost, final Integer port, final String endpointPath) {
@@ -130,4 +178,12 @@ class YamlHttpTransportConfigurationSwapperTest {
result.setEndpointPath(endpointPath);
return result;
}
+
+ private YamlSessionAttributionSourceConfiguration createSessionAttributionSource(final String subjectHeader, final String sourceHeader, final String attributeHeaderPrefix) {
+ // Plain YAML holder populated field by field; no validation happens here.
+ YamlSessionAttributionSourceConfiguration result = new YamlSessionAttributionSourceConfiguration();
+ result.setAttributeHeaderPrefix(attributeHeaderPrefix);
+ result.setSourceHeader(sourceHeader);
+ result.setSubjectHeader(subjectHeader);
+ return result;
+ }
}
diff --git
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolverTest.java
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolverTest.java
new file mode 100644
index 00000000000..39441ba7ad8
--- /dev/null
+++
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/SessionAttributionResolverTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.mcp.bootstrap.transport.server.http;
+
+import jakarta.servlet.http.HttpServletRequest;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+class SessionAttributionResolverTest {
+
+ @Test
+ void assertResolveFromRequest() {
+ HttpServletRequest httpRequest = mock(HttpServletRequest.class);
+ when(httpRequest.getHeaderNames()).thenReturn(Collections.enumeration(List.of("X-Test-Subject", "X-Test-Source", "X-Test-ATTR-Region")));
+ when(httpRequest.getHeader("X-Test-Subject")).thenReturn("subject");
+ when(httpRequest.getHeader("X-Test-Source")).thenReturn("gateway");
+ when(httpRequest.getHeader("X-Test-ATTR-Region")).thenReturn("ap-south");
+ assertThat(createResolver().resolve(httpRequest).map(each -> each.getAttributes().get("region")), is(Optional.of("ap-south")));
+ }
+
+ @Test
+ void assertResolveFromHeaders() {
+ Map<String, List<String>> headers = Map.of("X-Test-Subject", List.of("subject"), "X-Test-Source", List.of("gateway"), "x-test-attr-Region", List.of("ap-south"));
+ assertThat(createResolver().resolve(headers).get().getAttributes().get("region"), is("ap-south"));
+ }
+
+ private SessionAttributionResolver createResolver() {
+ return new SessionAttributionResolver(new SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source", "X-Test-Attr-"));
+ }
+}
diff --git
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServletTest.java
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServletTest.java
index 4bc70f80140..eef1dfdf973 100644
---
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServletTest.java
+++
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/StreamableHttpMCPServletTest.java
@@ -26,7 +26,9 @@ import io.modelcontextprotocol.spec.ProtocolVersions;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
import
org.apache.shardingsphere.mcp.bootstrap.config.HttpTransportConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
import org.apache.shardingsphere.mcp.bootstrap.transport.MCPTransportConstants;
import
org.apache.shardingsphere.mcp.bootstrap.transport.MCPTransportJsonMapperFactory;
import
org.apache.shardingsphere.mcp.core.session.MCPSessionExecutionCoordinator;
@@ -42,6 +44,7 @@ import reactor.core.publisher.Mono;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
@@ -231,6 +234,30 @@ class StreamableHttpMCPServletTest {
verify(response).setHeader(HttpHeaders.PROTOCOL_VERSION,
MCPTransportConstants.PROTOCOL_VERSION);
}
+ @Test
+ void assertServicePostBindSessionAttribution() throws ServletException, IOException, ReflectiveOperationException {
+ HttpTransportConfiguration transportConfig = new HttpTransportConfiguration("127.0.0.1", 18088, "/mcp",
+ new SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source", "X-Test-Attr-"));
+ HttpServletRequest postRequest = mock(HttpServletRequest.class);
+ when(postRequest.getMethod()).thenReturn("POST");
+ when(postRequest.getHeader(HttpHeaders.ACCEPT)).thenReturn(ACCEPT);
+ when(postRequest.getHeaderNames()).thenReturn(Collections.enumeration(List.of("X-Test-Subject", "X-Test-Source", "X-Test-ATTR-Region")));
+ when(postRequest.getHeader("X-Test-Subject")).thenReturn("subject");
+ when(postRequest.getHeader("X-Test-Source")).thenReturn("gateway");
+ when(postRequest.getHeader("X-Test-ATTR-Region")).thenReturn("ap-south");
+ HttpServletStreamableServerTransportProvider delegate = mock(HttpServletStreamableServerTransportProvider.class);
+ when(delegate.closeGracefully()).thenReturn(Mono.empty());
+ MCPSessionManager boundSessionManager = mock(MCPSessionManager.class);
+ StreamableHttpMCPServlet servlet = createServlet(delegate, boundSessionManager, mock(MCPSessionExecutionCoordinator.class), transportConfig);
+ // The delegate issues the session ID on the response, as the real transport does for an initializing POST.
+ doAnswer(invocation -> {
+ HttpServletResponse delegatedResponse = invocation.getArgument(1);
+ delegatedResponse.setHeader(HttpHeaders.MCP_SESSION_ID, "session-id");
+ return null;
+ }).when(delegate).service(any(HttpServletRequest.class), any(HttpServletResponse.class));
+ servlet.service(postRequest, mock(HttpServletResponse.class));
+ verify(boundSessionManager).bindSessionAttribution("session-id", new MCPSessionAttribution("subject", "gateway", Map.of("region", "ap-south")));
+ }
+
@Test
void assertServicePostWithJsonContentType() throws ServletException,
IOException, ReflectiveOperationException {
HttpServletStreamableServerTransportProvider delegate =
mock(HttpServletStreamableServerTransportProvider.class);
diff --git
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactoryTest.java
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactoryTest.java
index 4273fb44b36..03a86ce2909 100644
---
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactoryTest.java
+++
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ServerTransportSecurityValidatorFactoryTest.java
@@ -19,6 +19,7 @@ package
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator;
import
io.modelcontextprotocol.server.transport.ServerTransportSecurityException;
import
io.modelcontextprotocol.server.transport.ServerTransportSecurityValidator;
+import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.SessionAttributionResolver;
import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
import org.junit.jupiter.api.Test;
@@ -30,15 +31,18 @@ import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
class ServerTransportSecurityValidatorFactoryTest {
+ // NOTE(review): built from a null source configuration; the constant name implies this yields a disabled resolver — confirm against SessionAttributionResolver#isEnabled.
+ private static final SessionAttributionResolver
DISABLED_SESSION_ATTRIBUTION_RESOLVER = new SessionAttributionResolver(null);
+
@Test
void assertCreateWithoutOptionalRules() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1");
+ ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
+ // With no headers at all (no session ID), the validator must not touch the session manager.
assertDoesNotThrow(() -> actual.validateHeaders(Map.of()));
verifyNoInteractions(sessionManager);
}
@@ -46,7 +50,7 @@ class ServerTransportSecurityValidatorFactoryTest {
@Test
void assertCreateWithLoopbackOrigin() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1");
+ ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
+ // Only an Origin header is sent; without a session ID the session manager is never consulted.
assertDoesNotThrow(() -> actual.validateHeaders(Map.of("Origin",
List.of("http://127.0.0.1:8080"))));
verifyNoInteractions(sessionManager);
}
@@ -54,17 +58,18 @@ class ServerTransportSecurityValidatorFactoryTest {
@Test
void assertCreateWithLoopbackOriginRejectsRemoteOrigin() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1");
+ // The session-attribution check runs before the header constraints and looks up the session ID even though the resolver is disabled.
+ when(sessionManager.hasSession("session-id")).thenReturn(true);
+ ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
ServerTransportSecurityException ex =
assertThrows(ServerTransportSecurityException.class,
() -> validator.validateHeaders(Map.of("Origin",
List.of("http://example.com:8080"), "Mcp-Session-Id", List.of("session-id"))));
assertThat(ex.getStatusCode(), is(403));
- verifyNoInteractions(sessionManager);
+ verify(sessionManager).hasSession("session-id");
}
@Test
void assertCreateWithNonLoopbackBindingAcceptsMissingOrigin() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "0.0.0.0");
+ ServerTransportSecurityValidator actual =
ServerTransportSecurityValidatorFactory.create(sessionManager, "0.0.0.0",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
+ // No headers, so neither the attribution check nor any constraint consults the session manager.
assertDoesNotThrow(() -> actual.validateHeaders(Map.of()));
verifyNoInteractions(sessionManager);
}
@@ -72,7 +77,7 @@ class ServerTransportSecurityValidatorFactoryTest {
@Test
void assertCreateWithNonLoopbackBindingRejectsPresentOrigin() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "0.0.0.0");
+ ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "0.0.0.0",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
ServerTransportSecurityException ex =
assertThrows(ServerTransportSecurityException.class,
() -> validator.validateHeaders(Map.of("Origin",
List.of("https://gateway.example.test"))));
assertThat(ex.getStatusCode(), is(403));
@@ -82,7 +87,7 @@ class ServerTransportSecurityValidatorFactoryTest {
@Test
void assertCreateRejectsInvalidOrigin() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1");
+ ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
ServerTransportSecurityException ex =
assertThrows(ServerTransportSecurityException.class, () ->
validator.validateHeaders(Map.of("Origin", List.of("://bad-origin"))));
assertThat(ex.getStatusCode(), is(403));
verifyNoInteractions(sessionManager);
@@ -92,7 +97,7 @@ class ServerTransportSecurityValidatorFactoryTest {
void assertCreateWithProtocolVersionConstraintLast() {
MCPSessionManager sessionManager = mock(MCPSessionManager.class);
when(sessionManager.hasSession("session-id")).thenReturn(true);
- ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1");
+ ServerTransportSecurityValidator validator =
ServerTransportSecurityValidatorFactory.create(sessionManager, "127.0.0.1",
DISABLED_SESSION_ATTRIBUTION_RESOLVER);
ServerTransportSecurityException ex =
assertThrows(ServerTransportSecurityException.class,
() -> validator.validateHeaders(Map.of("Mcp-Session-Id",
List.of("session-id"))));
assertThat(ex.getStatusCode(), is(400));
diff --git
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidatorTest.java
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidatorTest.java
index e9665c41ce9..fd9d1006ced 100644
---
a/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidatorTest.java
+++
b/mcp/bootstrap/src/test/java/org/apache/shardingsphere/mcp/bootstrap/transport/server/http/validator/ShardingSphereServerTransportSecurityValidatorTest.java
@@ -19,94 +19,56 @@ package
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator;
import
io.modelcontextprotocol.server.transport.ServerTransportSecurityException;
import io.modelcontextprotocol.spec.HttpHeaders;
-import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.ProtocolVersionHeaderConstraint;
-import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.SessionRequiredTransportHeaderConstraint;
-import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.validator.constraint.TransportHeaderConstraint;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
+import
org.apache.shardingsphere.mcp.bootstrap.config.SessionAttributionSourceConfiguration;
+import
org.apache.shardingsphere.mcp.bootstrap.transport.server.http.SessionAttributionResolver;
import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
-import org.junit.jupiter.params.provider.MethodSource;
import java.util.List;
import java.util.Map;
-import java.util.stream.Stream;
+import java.util.Optional;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoInteractions;
-import static org.mockito.Mockito.when;
class ShardingSphereServerTransportSecurityValidatorTest {
@Test
- void assertValidateHeadersWithNoFailure() {
- ShardingSphereServerTransportSecurityValidator validator = new
ShardingSphereServerTransportSecurityValidator(mock(), List.of());
- assertDoesNotThrow(() -> validator.validateHeaders(Map.of()));
+ void assertValidateHeadersWithMatchedSessionAttribution() {
+ MCPSessionManager sessionManager = new MCPSessionManager(Map.of());
+ sessionManager.createSession("session-1");
+ sessionManager.bindSessionAttribution("session-1", new
MCPSessionAttribution("subject", "gateway", Map.of()));
+ ShardingSphereServerTransportSecurityValidator actual = new
ShardingSphereServerTransportSecurityValidator(sessionManager, List.of(),
+ new SessionAttributionResolver(new
SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source",
"X-Test-Attr-")));
+ // The request carries the same subject and source that were bound to the session, so validation passes.
+ assertDoesNotThrow(() ->
actual.validateHeaders(Map.of(HttpHeaders.MCP_SESSION_ID, List.of("session-1"),
"X-Test-Subject", List.of("subject"),
+ "X-Test-Source", List.of("gateway"))));
}
@Test
- void assertValidateHeadersWithFailure() throws
ServerTransportSecurityException {
- TransportHeaderConstraint first =
mock(TransportHeaderConstraint.class);
- doThrow(new ServerTransportSecurityException(401,
"Unauthorized.")).when(first).validate("");
- TransportHeaderConstraint second =
mock(TransportHeaderConstraint.class);
- ShardingSphereServerTransportSecurityValidator validator = new
ShardingSphereServerTransportSecurityValidator(mock(), List.of(first, second));
- ServerTransportSecurityException actual =
assertThrows(ServerTransportSecurityException.class, () ->
validator.validateHeaders(Map.of()));
- assertThat(actual.getStatusCode(), is(401));
- assertThat(actual.getMessage(), is("Unauthorized."));
- verifyNoInteractions(second);
+ void assertValidateHeadersWithMismatchedSessionAttribution() {
+ MCPSessionManager sessionManager = new MCPSessionManager(Map.of());
+ sessionManager.createSession("session-1");
+ sessionManager.bindSessionAttribution("session-1", new
MCPSessionAttribution("subject", "gateway", Map.of()));
+ ShardingSphereServerTransportSecurityValidator actual = new
ShardingSphereServerTransportSecurityValidator(sessionManager, List.of(),
+ new SessionAttributionResolver(new
SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source",
"X-Test-Attr-")));
+ // The request subject ("other") differs from the bound subject ("subject"), so the request is rejected.
+ ServerTransportSecurityException exception =
assertThrows(ServerTransportSecurityException.class, () ->
actual.validateHeaders(
+ Map.of(HttpHeaders.MCP_SESSION_ID, List.of("session-1"),
"X-Test-Subject", List.of("other"), "X-Test-Source", List.of("gateway"))));
+ assertThat(exception.getMessage(), is("Session attribution does not
match existing binding for session `session-1`."));
}
@Test
- void assertValidateHeadersWithExistingSessionMissingProtocolVersion() {
- MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- when(sessionManager.hasSession("session-id")).thenReturn(true);
- ShardingSphereServerTransportSecurityValidator validator = new
ShardingSphereServerTransportSecurityValidator(sessionManager, List.of(new
ProtocolVersionHeaderConstraint()));
- ServerTransportSecurityException actual =
assertThrows(ServerTransportSecurityException.class,
- () ->
validator.validateHeaders(Map.of(HttpHeaders.MCP_SESSION_ID,
List.of("session-id"))));
- assertThat(actual.getStatusCode(), is(400));
- assertThat(actual.getMessage(), is("MCP-Protocol-Version header is
required."));
- verify(sessionManager).hasSession("session-id");
- }
-
- @ParameterizedTest(name = "{0}")
- @MethodSource("assertValidateHeadersWithSessionConstraintArguments")
- void assertValidateHeadersWithSessionConstraint(final String name, final
Map<String, List<String>> headers,
- final String
expectedSessionId, final boolean sessionExists, final String
expectedConstraintValue) throws ServerTransportSecurityException {
- MCPSessionManager sessionManager = mock(MCPSessionManager.class);
- SessionRequiredTransportHeaderConstraint constraint =
mock(SessionRequiredTransportHeaderConstraint.class);
- ShardingSphereServerTransportSecurityValidator validator = new
ShardingSphereServerTransportSecurityValidator(sessionManager,
List.of(constraint));
- if (null != expectedSessionId) {
-
when(sessionManager.hasSession(expectedSessionId)).thenReturn(sessionExists);
- }
- if (null != expectedConstraintValue) {
- when(constraint.getConstraintKey()).thenReturn("Authorization");
- }
- assertDoesNotThrow(() -> validator.validateHeaders(headers));
- if (null == expectedSessionId) {
- verifyNoInteractions(sessionManager);
- } else {
- verify(sessionManager).hasSession(expectedSessionId);
- }
- if (null == expectedConstraintValue) {
- verifyNoInteractions(constraint);
- } else {
- verify(constraint).getConstraintKey();
- verify(constraint).validate(expectedConstraintValue);
- }
- }
-
- private static Stream<Arguments>
assertValidateHeadersWithSessionConstraintArguments() {
- return Stream.of(
- Arguments.of("skip without session id",
Map.of("Authorization", List.of("Bearer foo_token")), null, false, null),
- Arguments.of("skip with empty session header",
Map.of("Mcp-Session-Id", List.<String>of()), null, false, null),
- Arguments.of("skip for unknown session",
Map.of("Mcp-Session-Id", List.of("session-id")), "session-id", false, null),
- Arguments.of("validate with existing session",
- Map.of("mcp-session-id", List.of(" session-id "),
"authorization", List.of(" Bearer foo_token ")), "session-id", true, "Bearer
foo_token"));
+ void assertValidateHeadersWithUnboundSessionAttribution() {
+ MCPSessionManager sessionManager = new MCPSessionManager(Map.of());
+ sessionManager.createSession("session-1");
+ ShardingSphereServerTransportSecurityValidator actual = new
ShardingSphereServerTransportSecurityValidator(sessionManager, List.of(),
+ new SessionAttributionResolver(new
SessionAttributionSourceConfiguration("X-Test-Subject", "X-Test-Source",
"X-Test-Attr-")));
+ // The session exists but has no bound attribution: the validator rejects with 400 and must not bind one as a side effect.
+ ServerTransportSecurityException exception =
assertThrows(ServerTransportSecurityException.class, () ->
actual.validateHeaders(
+ Map.of(HttpHeaders.MCP_SESSION_ID, List.of("session-1"),
"X-Test-Subject", List.of("subject"), "X-Test-Source", List.of("gateway"))));
+ assertThat(exception.getStatusCode(), is(400));
+ assertThat(exception.getMessage(), is("Session attribution is not
bound for session `session-1`."));
+ assertThat(sessionManager.findSessionAttribution("session-1"),
is(Optional.empty()));
}
}
diff --git
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/completion/MCPCompletionService.java
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/completion/MCPCompletionService.java
index e01ae8e2a4f..8b1328bc661 100644
---
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/completion/MCPCompletionService.java
+++
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/completion/MCPCompletionService.java
@@ -136,7 +136,7 @@ public final class MCPCompletionService {
MCPCompletionRequestContext requestContext = new
MCPCompletionRequestContext(sessionId, descriptor, argumentName,
contextArguments);
for (MCPCompletionProvider<?> each : completionProviders) {
if (each.supports(requestContext)) {
- try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext)) {
+ try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext, sessionId)) {
return completeCandidates(requestScope, each,
requestContext);
}
}
diff --git
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScope.java
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScope.java
index 0688d091197..647277fcf16 100644
---
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScope.java
+++
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScope.java
@@ -19,6 +19,7 @@ package org.apache.shardingsphere.mcp.core.context;
import lombok.AccessLevel;
import lombok.Getter;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
import
org.apache.shardingsphere.mcp.core.tool.handler.execute.MCPSQLExecutionFacade;
import org.apache.shardingsphere.mcp.core.workflow.WorkflowProxyQueryService;
@@ -33,6 +34,8 @@ import
org.apache.shardingsphere.mcp.support.database.spi.MCPMetadataQueryFacade
import
org.apache.shardingsphere.mcp.support.workflow.MCPWorkflowHandlerContext;
import org.apache.shardingsphere.mcp.support.workflow.WorkflowSessionContext;
+import java.util.Optional;
+
/**
* MCP request scope.
*/
@@ -47,6 +50,8 @@ public final class MCPRequestScope implements
MCPServiceHandlerContext, MCPDatab
@Getter(AccessLevel.NONE)
private final RequestScopedMetadataContext metadataContext;
+ private final Optional<MCPSessionAttribution> sessionAttribution;
+
private final WorkflowSessionContext workflowSessionContext;
private final MCPMetadataQueryFacade metadataQueryFacade;
@@ -61,10 +66,21 @@ public final class MCPRequestScope implements
MCPServiceHandlerContext, MCPDatab
* @param runtimeContext runtime context
*/
public MCPRequestScope(final MCPRuntimeContext runtimeContext) {
+ this(runtimeContext, "");
+ }
+
+ /**
+ * Create MCP request scope.
+ *
+ * @param runtimeContext runtime context
+ * @param sessionId session id
+ */
+ public MCPRequestScope(final MCPRuntimeContext runtimeContext, final
String sessionId) {
MCPSessionManager sessionManager = runtimeContext.getSessionManager();
activeTransport = runtimeContext.getActiveTransport();
databaseCapabilityProvider =
runtimeContext.getDatabaseCapabilityProvider();
metadataContext = new
RequestScopedMetadataContext(sessionManager.getTransactionResourceManager().getRuntimeDatabases(),
databaseCapabilityProvider);
+ sessionAttribution = sessionManager.findSessionAttribution(sessionId);
workflowSessionContext = runtimeContext.getWorkflowSessionContext();
metadataQueryFacade = new
MetadataQueryService(databaseCapabilityProvider, metadataContext);
executionFacade = new
MCPSQLExecutionFacade(databaseCapabilityProvider, sessionManager);
@@ -81,6 +97,11 @@ public final class MCPRequestScope implements
MCPServiceHandlerContext, MCPDatab
return databaseCapabilityProvider;
}
+ @Override
+ public Optional<MCPSessionAttribution> findSessionAttribution() {
+ return sessionAttribution;
+ }
+
@Override
public void close() {
metadataContext.close();
diff --git
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManager.java
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManager.java
index abb2626a588..04ac17377ee 100644
---
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManager.java
+++
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManager.java
@@ -19,12 +19,14 @@ package org.apache.shardingsphere.mcp.core.session;
import lombok.Getter;
import org.apache.shardingsphere.infra.exception.ShardingSpherePreconditions;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
import
org.apache.shardingsphere.mcp.support.database.metadata.jdbc.RuntimeDatabaseConfiguration;
import
org.apache.shardingsphere.mcp.core.tool.handler.execute.MCPJdbcTransactionResourceManager;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ConcurrentHashMap;
@@ -41,6 +43,8 @@ public final class MCPSessionManager {
private final Map<String, ReentrantLock> sessions = new
ConcurrentHashMap<>();
+ private final Map<String, MCPSessionAttribution> sessionAttributions = new
ConcurrentHashMap<>();
+
private final List<Consumer<String>> sessionCloseListeners = new
CopyOnWriteArrayList<>();
public MCPSessionManager(final Map<String, RuntimeDatabaseConfiguration>
databases) {
@@ -56,6 +60,29 @@ public final class MCPSessionManager {
ShardingSpherePreconditions.checkState(null ==
sessions.putIfAbsent(sessionId, new ReentrantLock(true)), () -> new
IllegalStateException("Session already exists."));
}
+ /**
+ * Bind session attribution to one existing session.
+ *
+ * @param sessionId session id
+ * @param sessionAttribution session attribution
+ */
+ public void bindSessionAttribution(final String sessionId, final
MCPSessionAttribution sessionAttribution) {
+ ShardingSpherePreconditions.checkState(hasSession(sessionId),
MCPSessionNotExistedException::new);
+ MCPSessionAttribution existing =
sessionAttributions.putIfAbsent(sessionId, sessionAttribution);
+ ShardingSpherePreconditions.checkState(null == existing ||
existing.equals(sessionAttribution),
+ () -> new IllegalStateException(String.format("Session
attribution does not match existing binding for session `%s`.", sessionId)));
+ }
+
+ /**
+ * Find session attribution.
+ *
+ * @param sessionId session id
+ * @return session attribution
+ */
+ public Optional<MCPSessionAttribution> findSessionAttribution(final String
sessionId) {
+ return Optional.ofNullable(sessionAttributions.get(sessionId));
+ }
+
/**
* Determine whether a session exists.
*
@@ -89,6 +116,7 @@ public final class MCPSessionManager {
transactionResourceManager.closeSession(sessionId);
} finally {
if (sessions.remove(sessionId, executionLock)) {
+ sessionAttributions.remove(sessionId);
notifySessionCloseListeners(sessionId);
}
}
diff --git
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/tool/MCPToolController.java
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/tool/MCPToolController.java
index e97d946feab..8197a710881 100644
---
a/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/tool/MCPToolController.java
+++
b/mcp/core/src/main/java/org/apache/shardingsphere/mcp/core/tool/MCPToolController.java
@@ -70,7 +70,7 @@ public final class MCPToolController {
public MCPResponse handle(final String sessionId, final MCPToolDefinition
toolDefinition, final Map<String, Object> arguments) {
try {
toolCallLimiter.acquire(sessionId,
toolDefinition.getDescriptor().getName());
- try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext)) {
+ try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext, sessionId)) {
return ToolDefinitionRegistry.dispatch(requestScope,
toolDefinition, sessionId, arguments);
}
// CHECKSTYLE:OFF
diff --git
a/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScopeTest.java
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScopeTest.java
new file mode 100644
index 00000000000..572c7c19149
--- /dev/null
+++
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/context/MCPRequestScopeTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.mcp.core.context;
+
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
+import org.apache.shardingsphere.mcp.core.session.MCPSessionManager;
+import
org.apache.shardingsphere.mcp.support.database.capability.MCPDatabaseCapabilityProvider;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+import java.util.Optional;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+
+class MCPRequestScopeTest {
+
+ @Test
+ void assertFindSessionAttribution() {
+ MCPSessionManager sessionManager = new MCPSessionManager(Map.of());
+ MCPSessionAttribution sessionAttribution = new
MCPSessionAttribution("subject", "gateway", Map.of("cluster", "demo"));
+ sessionManager.createSession("session-1");
+ sessionManager.bindSessionAttribution("session-1", sessionAttribution);
+ MCPRuntimeContext runtimeContext = new
MCPRuntimeContext(sessionManager, new MCPDatabaseCapabilityProvider(Map.of()),
"http");
+ try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext, "session-1")) {
+ assertThat(requestScope.findSessionAttribution(),
is(Optional.of(sessionAttribution)));
+ }
+ }
+
+ @Test
+ void assertFindSessionAttributionWithoutSession() {
+ MCPRuntimeContext runtimeContext = new MCPRuntimeContext(new
MCPSessionManager(Map.of()), new MCPDatabaseCapabilityProvider(Map.of()),
"http");
+ try (MCPRequestScope requestScope = new
MCPRequestScope(runtimeContext)) {
+ assertThat(requestScope.findSessionAttribution(),
is(Optional.empty()));
+ }
+ }
+}
diff --git
a/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/resource/handler/capability/ServerCapabilitiesHandlerTest.java
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/resource/handler/capability/ServerCapabilitiesHandlerTest.java
index dd6c8006ef0..673f93ae490 100644
---
a/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/resource/handler/capability/ServerCapabilitiesHandlerTest.java
+++
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/resource/handler/capability/ServerCapabilitiesHandlerTest.java
@@ -194,7 +194,7 @@ class ServerCapabilitiesHandlerTest {
assertTrue(actual.containsKey("stdio_stdout"));
Map<?, ?> actualClientSafetyPolicy = (Map<?, ?>)
actual.get("client_safety_policy");
assertThat(actualClientSafetyPolicy.get("identity_scope"),
is("mcp_session"));
-
assertTrue(String.valueOf(actualClientSafetyPolicy.get("transport_scope")).contains("no
built-in authorization"));
+
assertTrue(String.valueOf(actualClientSafetyPolicy.get("transport_scope")).contains("trusted
session attribution"));
assertThat(((Map<?, ?>)
actualClientSafetyPolicy.get("tool_call_limit")).get("scope"), is("session"));
assertTrue(String.valueOf(actualClientSafetyPolicy.get("abuse_guard")).contains("counted
before dispatch"));
}
diff --git
a/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManagerTest.java
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManagerTest.java
index debe6ca5992..ebb6e8b9d37 100644
---
a/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManagerTest.java
+++
b/mcp/core/src/test/java/org/apache/shardingsphere/mcp/core/session/MCPSessionManagerTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.mcp.core.session;
+import org.apache.shardingsphere.mcp.api.session.MCPSessionAttribution;
import
org.apache.shardingsphere.mcp.support.database.metadata.jdbc.RuntimeDatabaseConfiguration;
import org.junit.jupiter.api.Test;
@@ -26,6 +27,7 @@ import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
@@ -132,4 +134,32 @@ class MCPSessionManagerTest {
assertFalse(sessionManager.hasSession("session-1"));
assertFalse(sessionManager.hasSession("session-2"));
}
+
+ @Test
+ void assertBindSessionAttribution() {
+ MCPSessionManager sessionManager = new
MCPSessionManager(Collections.emptyMap());
+ MCPSessionAttribution sessionAttribution = new
MCPSessionAttribution("subject", "gateway", Map.of("region", "ap-south"));
+ sessionManager.createSession("session-1");
+ assertDoesNotThrow(() ->
sessionManager.bindSessionAttribution("session-1", sessionAttribution));
+ assertThat(sessionManager.findSessionAttribution("session-1"),
is(Optional.of(sessionAttribution)));
+ }
+
+ @Test
+ void assertBindSessionAttributionWithDifferentBinding() {
+ MCPSessionManager sessionManager = new
MCPSessionManager(Collections.emptyMap());
+ sessionManager.createSession("session-1");
+ sessionManager.bindSessionAttribution("session-1", new
MCPSessionAttribution("subject", "gateway", Map.of()));
+ IllegalStateException actual =
assertThrows(IllegalStateException.class,
+ () -> sessionManager.bindSessionAttribution("session-1", new
MCPSessionAttribution("other", "gateway", Map.of())));
+ assertThat(actual.getMessage(), is("Session attribution does not match
existing binding for session `session-1`."));
+ }
+
+ @Test
+ void assertCloseSessionRemovesSessionAttribution() {
+ MCPSessionManager sessionManager = new
MCPSessionManager(Collections.emptyMap());
+ sessionManager.createSession("session-1");
+ sessionManager.bindSessionAttribution("session-1", new
MCPSessionAttribution("subject", "gateway", Map.of()));
+ sessionManager.closeSession("session-1");
+ assertThat(sessionManager.findSessionAttribution("session-1"),
is(Optional.empty()));
+ }
}
diff --git
a/mcp/support/src/main/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicy.java
b/mcp/support/src/main/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicy.java
index 04b4f968068..d6f74352886 100644
---
a/mcp/support/src/main/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicy.java
+++
b/mcp/support/src/main/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicy.java
@@ -52,7 +52,10 @@ public final class MCPClientSafetyPolicy {
public static Map<String, Object> createModelFacingPayload() {
Map<String, Object> result = new LinkedHashMap<>(5, 1F);
result.put("identity_scope", "mcp_session");
- result.put("transport_scope", "HTTP transport has no built-in
authorization in this release; STDIO inherits the local process boundary.");
+ result.put("transport_scope",
+ "HTTP transport can bind trusted session attribution when
configured; "
+ + "authentication and authorization remain outside the
runtime in this release. "
+ + "STDIO inherits the local process boundary.");
result.put("tool_call_limit", createToolCallLimitPayload());
result.put("abuse_guard", "Every tool call is counted before dispatch,
including invalid calls, so runaway model loops stop at the session quota.");
result.put("external_model_boundary", "The MCP runtime never calls
external model providers; live LLM E2E clients call configured endpoints
outside the server.");
diff --git
a/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/descriptor/MCPModelFirstContractPayloadBuilderTest.java
b/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/descriptor/MCPModelFirstContractPayloadBuilderTest.java
index eb276b9a663..93a67c66645 100644
---
a/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/descriptor/MCPModelFirstContractPayloadBuilderTest.java
+++
b/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/descriptor/MCPModelFirstContractPayloadBuilderTest.java
@@ -85,7 +85,7 @@ class MCPModelFirstContractPayloadBuilderTest {
assertTrue(String.valueOf(actual.get("origin_header")).contains("loopback
origins"));
Map<?, ?> actualClientSafetyPolicy =
castToMap(actual.get("client_safety_policy"));
assertThat(actualClientSafetyPolicy.get("identity_scope"),
is("mcp_session"));
-
assertTrue(String.valueOf(actualClientSafetyPolicy.get("transport_scope")).contains("no
built-in authorization"));
+
assertTrue(String.valueOf(actualClientSafetyPolicy.get("transport_scope")).contains("trusted
session attribution"));
assertThat(castToMap(actualClientSafetyPolicy.get("tool_call_limit")).get("scope"),
is("session"));
assertTrue(String.valueOf(actualClientSafetyPolicy.get("external_model_boundary")).contains("never
calls external model providers"));
}
diff --git
a/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicyTest.java
b/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicyTest.java
index 5d3f6bbd1ed..8e910d02d25 100644
---
a/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicyTest.java
+++
b/mcp/support/src/test/java/org/apache/shardingsphere/mcp/support/security/MCPClientSafetyPolicyTest.java
@@ -44,6 +44,7 @@ class MCPClientSafetyPolicyTest {
void assertCreateModelFacingPayload() {
Map<String, Object> actual =
MCPClientSafetyPolicy.createModelFacingPayload();
assertThat(actual.get("identity_scope"), is("mcp_session"));
+
assertTrue(String.valueOf(actual.get("transport_scope")).contains("trusted
session attribution"));
assertTrue(String.valueOf(actual.get("external_model_boundary")).contains("never
calls external model providers"));
Map<?, ?> actualToolCallLimit = (Map<?, ?>)
actual.get("tool_call_limit");
assertThat(actualToolCallLimit.get("scope"), is("session"));