This is an automated email from the ASF dual-hosted git repository.
o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8cd923206d1 Update Notebook Operator to pull S3 project bucket from
tooling env (#67915)
8cd923206d1 is described below
commit 8cd923206d1ab439037a2e11352304f7b750720c
Author: Qazi Ashikin <[email protected]>
AuthorDate: Wed Jun 3 19:56:44 2026 -0400
Update Notebook Operator to pull S3 project bucket from tooling env (#67915)
Resolve the bucket the same way the SMUS Workflows UI does:
1. List the project's "Tooling" environments via the DataZone
   ListEnvironmentBlueprints and ListEnvironments APIs.
2. Fall back to "ToolingLite" when no Tooling environment exists.
3. Call GetEnvironment to read provisionedResources, and parse the bucket
   name from the s3BucketPath resource.
The operator and sensor now pass domain_identifier through to
get_notebook_outputs so the lookup has both identifiers it needs.
---
.../docs/operators/sagemakerunifiedstudio.rst | 17 +-
.../aws/hooks/sagemaker_unified_studio_notebook.py | 166 +++++++++-
.../operators/sagemaker_unified_studio_notebook.py | 1 +
.../sensors/sagemaker_unified_studio_notebook.py | 1 +
.../test_sagemaker_unified_studio_notebook.py | 351 +++++++++++++++------
.../test_sagemaker_unified_studio_notebook.py | 1 +
.../test_sagemaker_unified_studio_notebook.py | 1 +
7 files changed, 425 insertions(+), 113 deletions(-)
diff --git a/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
index 0f6510633f8..cee556bf724 100644
--- a/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
+++ b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst
@@ -56,6 +56,14 @@ The artifact is identified by its relative file path within
the project (e.g. ``
:start-after: [START howto_operator_sagemaker_unified_studio_notebook]
:end-before: [END howto_operator_sagemaker_unified_studio_notebook]
+The following example adds domain ID, project ID, and domain name as operator
parameters.
+
+.. exampleinclude::
/../../amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
+ :language: python
+ :dedent: 4
+ :start-after: [START
howto_operator_sagemaker_unified_studio_notebook_explicit_params]
+ :end-before: [END
howto_operator_sagemaker_unified_studio_notebook_explicit_params]
+
.. _howto/operator:SageMakerUnifiedStudioNotebookOperator:
Run SageMaker Unified Studio notebooks
@@ -73,15 +81,6 @@ where the notebook resides.
:start-after: [START howto_operator_sagemaker_unified_studio_notebook]
:end-before: [END howto_operator_sagemaker_unified_studio_notebook]
-
-The following example adds domain ID, project ID, and domain name as operator
parameters.
-
-.. exampleinclude::
/../../amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py
- :language: python
- :dedent: 4
- :start-after: [START
howto_operator_sagemaker_unified_studio_notebook_explicit_params]
- :end-before: [END
howto_operator_sagemaker_unified_studio_notebook_explicit_params]
-
Notebooks can produce output variables that are automatically pushed to XCom
when the run completes.
Downstream tasks can consume these outputs via Jinja templating in
``notebook_parameters``.
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py
index 779eaf45554..20ed57b3a2b 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py
@@ -26,6 +26,7 @@ import time
import uuid
from functools import cached_property
from typing import Any
+from urllib.parse import urlparse
from botocore.exceptions import ClientError
@@ -246,21 +247,160 @@ class SageMakerUnifiedStudioNotebookHook(AwsBaseHook):
error_message = execution_message
raise RuntimeError(error_message)
- def get_project_s3_path(self, project_id: str) -> str:
+ def get_project_s3_path(self, domain_identifier: str, project_id: str) ->
str:
"""
- Construct the S3 path for a SageMaker Unified Studio project bucket.
+ Look up the S3 bucket name for a SageMaker Unified Studio project.
+
+ The bucket is parsed from the ``s3BucketPath`` provisioned resource of
+ the project's default tooling environment, which is resolved via the
+ DataZone ``ListEnvironmentBlueprints``, ``ListEnvironments`` and
+ ``GetEnvironment`` APIs (``Tooling`` first, then ``ToolingLite``). This
+ mirrors how SageMaker Unified Studio resolves the project bucket, and
+ accommodates projects whose bucket name does not follow the
+ ``amazon-sagemaker-{account_id}-{region}-{project_id}`` template (for
+ example, BYOR-bucket projects).
+
+ :param domain_identifier: The ID of the DataZone domain.
:param project_id: The ID of the DataZone project.
:return: The S3 bucket name for the project.
+ :raises RuntimeError: If the default tooling environment cannot be
+ resolved, or the ``s3BucketPath`` provisioned resource is missing,
+ empty, or malformed.
+ """
+ environment = self._get_default_tooling_environment(domain_identifier,
project_id)
+ environment_id = environment.get("id")
+ provisioned_resources = environment.get("provisionedResources", []) or
[]
+ for resource in provisioned_resources:
+ if resource.get("name") == "s3BucketPath":
+ value = resource.get("value")
+ if not value:
+ raise RuntimeError(
+ f"s3BucketPath provisioned resource is empty in
default tooling "
+ f"environment {environment_id} for project
{project_id} in domain "
+ f"{domain_identifier}"
+ )
+ # value looks like "s3://<bucket>/<prefix>"; return the bucket
name only.
+ parts = urlparse(value, allow_fragments=False)
+ bucket = parts.netloc
+ if not bucket:
+ raise RuntimeError(
+ f"s3BucketPath provisioned resource has unexpected
format "
+ f"'{value}' in default tooling environment
{environment_id} for "
+ f"project {project_id} in domain {domain_identifier}"
+ )
+ return bucket
+
+ raise RuntimeError(
+ f"s3BucketPath provisioned resource not found in default tooling
environment "
+ f"{environment_id} for project {project_id} in domain
{domain_identifier}"
+ )
+
+ def _get_default_tooling_environment(self, domain_identifier: str,
project_id: str) -> dict:
+ """
+ Resolve the project's default ("Tooling") environment via DataZone
APIs.
+
+ 1. ``ListEnvironmentBlueprints(managed=True, name="Tooling")`` →
+ resolve the Tooling blueprint id.
+ 2. ``ListEnvironments(environmentBlueprintIdentifier=...)`` →
+ list the project's tooling environments.
+ 3. Pick the environment with the smallest non-null ``deploymentOrder``
+ as the default; if no environment has one, use the first listed
+ environment. If the Tooling blueprint has no environments for the
+ project, repeat steps 1-3 with the ``ToolingLite`` blueprint.
+ 4. ``GetEnvironment(identifier=...)`` → read the full record (including
+ ``provisionedResources``).
+
+ :param domain_identifier: The ID of the DataZone domain.
+ :param project_id: The ID of the DataZone project.
+ :return: The full environment dict from ``GetEnvironment``.
+ :raises RuntimeError: If an environment blueprint lookup returns no
+ blueprint, no default Tooling/ToolingLite environment is found, or the
+ DataZone APIs return an error.
"""
- account_id = self.account_id
- region = self.conn_region_name
- return f"amazon-sagemaker-{account_id}-{region}-{project_id}"
+ try:
+ default_env_summary =
self._find_default_tooling_environment_summary(
+ domain_identifier=domain_identifier,
+ project_id=project_id,
+ blueprint_name="Tooling",
+ )
+ if default_env_summary is None:
+ default_env_summary =
self._find_default_tooling_environment_summary(
+ domain_identifier=domain_identifier,
+ project_id=project_id,
+ blueprint_name="ToolingLite",
+ )
+ if default_env_summary is None:
+ raise RuntimeError(
+ f"No default Tooling or ToolingLite environment found for
project "
+ f"{project_id} in domain {domain_identifier}"
+ )
+
+ return self.conn.get_environment(
+ domainIdentifier=domain_identifier,
+ identifier=default_env_summary["id"],
+ )
+ except ClientError as e:
+ raise RuntimeError(
+ f"Failed to resolve default tooling environment for project
{project_id} "
+ f"in domain {domain_identifier}: {e}"
+ ) from e
+
+ def _find_default_tooling_environment_summary(
+ self,
+ domain_identifier: str,
+ project_id: str,
+ blueprint_name: str,
+ ) -> dict | None:
+ """
+ Resolve the default tooling environment summary for a given blueprint.
+
+ Returns ``None`` when the blueprint has no environments for the project
+ (so the caller can fall back to ``ToolingLite``). When environments
+ exist, prefers the one with the lowest non-null ``deploymentOrder``;
+ when ``deploymentOrder`` is absent on every env (the field is optional
+ in the DataZone response shape), falls back to the first item.
+
+ Raises ``RuntimeError`` only when the blueprint itself is missing.
+
+ :param domain_identifier: The ID of the DataZone domain.
+ :param project_id: The ID of the DataZone project.
+ :param blueprint_name: ``"Tooling"`` or ``"ToolingLite"``.
+ :return: The environment summary dict, or ``None``.
+ """
+ blueprints = (
+ self.conn.list_environment_blueprints(
+ domainIdentifier=domain_identifier,
+ managed=True,
+ name=blueprint_name,
+ ).get("items", [])
+ or []
+ )
+ if not blueprints:
+ raise RuntimeError(
+ f"{blueprint_name} environment blueprint not found in domain
{domain_identifier}"
+ )
+ blueprint_id = blueprints[0]["id"]
+
+ environments = (
+ self.conn.list_environments(
+ domainIdentifier=domain_identifier,
+ projectIdentifier=project_id,
+ environmentBlueprintIdentifier=blueprint_id,
+ ).get("items", [])
+ or []
+ )
+
+ if not environments:
+ return None
+
+ ordered = [env for env in environments if env.get("deploymentOrder")
is not None]
+ if ordered:
+ return min(ordered, key=lambda env: env["deploymentOrder"])
+ # ``deploymentOrder`` is optional in the EnvironmentSummary shape; when
+ # absent on every item, fall back to the first env for this blueprint.
+ return environments[0]
def get_notebook_outputs(
self,
notebook_identifier: str,
notebook_run_id: str,
+ domain_identifier: str,
owning_project_identifier: str,
) -> dict[str, Any]:
"""
@@ -272,14 +412,26 @@ class SageMakerUnifiedStudioNotebookHook(AwsBaseHook):
:param notebook_identifier: The ID of the notebook that was executed.
:param notebook_run_id: The ID of the completed notebook run.
+ :param domain_identifier: The ID of the DataZone domain.
:param owning_project_identifier: The ID of the DataZone project.
:return: A dict of notebook output key-value pairs. Returns an empty
dict
if no outputs were written or the file cannot be parsed.
"""
- bucket = self.get_project_s3_path(owning_project_identifier)
+ log = logging.getLogger(__name__)
+ try:
+ bucket = self.get_project_s3_path(domain_identifier,
owning_project_identifier)
+ except Exception:
+ log.warning(
+ "Failed to resolve project S3 bucket for project %s in domain
%s, "
+ "skipping notebook outputs read.",
+ owning_project_identifier,
+ domain_identifier,
+ exc_info=True,
+ )
+ return {}
+
key =
f"sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json"
- log = logging.getLogger(__name__)
log.info("Reading notebook outputs from s3://%s/%s", bucket, key)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id,
region_name=self.conn_region_name)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py
index 8d5558c9c45..e10ed90a412 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py
@@ -157,6 +157,7 @@ class
SageMakerUnifiedStudioNotebookOperator(AwsBaseOperator[SageMakerUnifiedStu
outputs = self.hook.get_notebook_outputs(
notebook_identifier=self.notebook_identifier,
notebook_run_id=notebook_run_id,
+ domain_identifier=self.domain_identifier,
owning_project_identifier=self.owning_project_identifier,
)
if outputs:
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py
index 0d59fe99ecc..c5dc4fc13f4 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py
@@ -126,6 +126,7 @@ class
SageMakerUnifiedStudioNotebookSensor(AwsBaseSensor[SageMakerUnifiedStudioN
outputs = self.hook.get_notebook_outputs(
notebook_identifier=self.notebook_identifier,
notebook_run_id=self.notebook_run_id,
+ domain_identifier=self.domain_identifier,
owning_project_identifier=self.owning_project_identifier,
)
if outputs:
diff --git
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py
index 7e8040d99ae..92023431161 100644
---
a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py
@@ -29,8 +29,6 @@ DOMAIN_ID = "dzd_example"
PROJECT_ID = "proj_example"
NOTEBOOK_ID = "notebook_123"
NOTEBOOK_RUN_ID = "run_456"
-ACCOUNT_ID = "123456789012"
-REGION = "us-west-2"
HOOK_MODULE =
"airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook"
@@ -284,73 +282,249 @@ class TestSageMakerUnifiedStudioNotebookHook:
# --- get_project_s3_path ---
- def test_get_project_s3_path(self):
- """Constructs the correct S3 bucket name from account_id, region, and
project_id."""
- with (
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
- result = self.hook.get_project_s3_path(PROJECT_ID)
- assert result == f"amazon-sagemaker-{ACCOUNT_ID}-{REGION}-{PROJECT_ID}"
+ def _stub_tooling_blueprint_lookup(
+ self,
+ environments: list[dict] | None = None,
+ tooling_lite_environments: list[dict] | None = None,
+ ) -> None:
+ """Set up list_environment_blueprints + list_environments for the
tooling lookup.
+
+ Returns the Tooling blueprint with id ``bp-tooling`` and lists
``environments``
+ for it. If ``tooling_lite_environments`` is provided, also returns a
+ ToolingLite blueprint (``bp-tooling-lite``) and lists those for it.
+ """
+ environments = environments or []
+
+ def list_envs(**kwargs):
+ blueprint_id = kwargs.get("environmentBlueprintIdentifier")
+ if blueprint_id == "bp-tooling":
+ return {"items": environments}
+ if blueprint_id == "bp-tooling-lite":
+ return {"items": tooling_lite_environments or []}
+ return {"items": []}
+
+ def list_blueprints(**kwargs):
+ name = kwargs.get("name")
+ if name == "Tooling":
+ return {"items": [{"id": "bp-tooling", "name": "Tooling"}]}
+ if name == "ToolingLite":
+ return {"items": [{"id": "bp-tooling-lite", "name":
"ToolingLite"}]}
+ return {"items": []}
+
+ self.mock_client.list_environment_blueprints.side_effect =
list_blueprints
+ self.mock_client.list_environments.side_effect = list_envs
+
+ def test_get_project_s3_path_uses_default_tooling_environment(self):
+ """Resolves bucket name from the default tooling environment's
s3BucketPath."""
+ env_id = "env-tooling-1"
+ bucket = "my-byor-bucket"
+ self._stub_tooling_blueprint_lookup(
+ environments=[
+ {"id": "env-other", "name": "Tooling", "deploymentOrder": 5},
+ {"id": env_id, "name": "Tooling", "deploymentOrder": 1},
+ ]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": env_id,
+ "provisionedResources": [
+ {"name": "userRoleArn", "value": "arn:aws:iam::123:role/foo"},
+ {"name": "s3BucketPath", "value":
f"s3://{bucket}/dzd_x/{PROJECT_ID}/dev"},
+ ],
+ }
+
+ result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ assert result == bucket
+ self.mock_client.list_environment_blueprints.assert_any_call(
+ domainIdentifier=DOMAIN_ID,
+ managed=True,
+ name="Tooling",
+ )
+ self.mock_client.list_environments.assert_called_once_with(
+ domainIdentifier=DOMAIN_ID,
+ projectIdentifier=PROJECT_ID,
+ environmentBlueprintIdentifier="bp-tooling",
+ )
+ self.mock_client.get_environment.assert_called_once_with(
+ domainIdentifier=DOMAIN_ID,
+ identifier=env_id,
+ )
+
+ def test_get_project_s3_path_picks_lowest_deployment_order(self):
+ """Picks the env with the lowest non-null deploymentOrder, ignoring
None."""
+ env_id = "env-tooling-default"
+ bucket = "default-bucket"
+ self._stub_tooling_blueprint_lookup(
+ environments=[
+ {"id": "env-no-order", "name": "Tooling", "deploymentOrder":
None},
+ {"id": "env-other", "name": "Tooling", "deploymentOrder": 9},
+ {"id": env_id, "name": "Tooling", "deploymentOrder": 1},
+ ]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": env_id,
+ "provisionedResources": [{"name": "s3BucketPath", "value":
f"s3://{bucket}/p"}],
+ }
+
+ result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ assert result == bucket
+ self.mock_client.get_environment.assert_called_once_with(
+ domainIdentifier=DOMAIN_ID,
+ identifier=env_id,
+ )
+
+ def test_get_project_s3_path_falls_back_to_tooling_lite(self):
+ """Falls back to ToolingLite when the Tooling blueprint has no envs."""
+ env_id = "env-lite-1"
+ bucket = "lite-bucket"
+ self._stub_tooling_blueprint_lookup(
+ environments=[],
+ tooling_lite_environments=[{"id": env_id, "name": "ToolingLite",
"deploymentOrder": 1}],
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": env_id,
+ "provisionedResources": [{"name": "s3BucketPath", "value":
f"s3://{bucket}/p"}],
+ }
+
+ result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ assert result == bucket
+ # Both blueprint lookups happened.
+ assert self.mock_client.list_environment_blueprints.call_count == 2
+ self.mock_client.get_environment.assert_called_once_with(
+ domainIdentifier=DOMAIN_ID,
+ identifier=env_id,
+ )
+
+ def
test_get_project_s3_path_falls_back_to_first_when_no_deployment_order(self):
+ """When envs exist but none has deploymentOrder, returns the first
env."""
+ env_id = "env-tooling-1"
+ bucket = "first-bucket"
+ self._stub_tooling_blueprint_lookup(
+ environments=[
+ {"id": env_id, "name": "AmazonSagemakerEnvironmentConfig-x"},
+ {"id": "env-tooling-2", "name":
"AmazonSagemakerEnvironmentConfig-y"},
+ ]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": env_id,
+ "provisionedResources": [{"name": "s3BucketPath", "value":
f"s3://{bucket}/p"}],
+ }
+
+ result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ assert result == bucket
+ self.mock_client.get_environment.assert_called_once_with(
+ domainIdentifier=DOMAIN_ID,
+ identifier=env_id,
+ )
+
+ def
test_get_project_s3_path_raises_when_no_environments_for_either_blueprint(self):
+ """Raises RuntimeError when neither Tooling nor ToolingLite has any
envs."""
+ self._stub_tooling_blueprint_lookup(environments=[],
tooling_lite_environments=[])
+ with pytest.raises(RuntimeError, match="No default Tooling or
ToolingLite environment found"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ def test_get_project_s3_path_raises_when_blueprint_missing(self):
+ """Raises RuntimeError when the Tooling blueprint is not registered in
the domain."""
+ self.mock_client.list_environment_blueprints.return_value = {"items":
[]}
+ with pytest.raises(RuntimeError, match="Tooling environment blueprint
not found"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ def test_get_project_s3_path_raises_when_resource_missing(self):
+ """Raises RuntimeError when s3BucketPath is not in
provisionedResources."""
+ self._stub_tooling_blueprint_lookup(
+ environments=[{"id": "env-1", "name": "Tooling",
"deploymentOrder": 1}]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": "env-1",
+ "provisionedResources": [{"name": "userRoleArn", "value":
"arn:aws:iam::123:role/foo"}],
+ }
+ with pytest.raises(RuntimeError, match="s3BucketPath provisioned
resource not found"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ def test_get_project_s3_path_raises_when_resource_value_empty(self):
+ """Raises RuntimeError when s3BucketPath is present but empty."""
+ self._stub_tooling_blueprint_lookup(
+ environments=[{"id": "env-1", "name": "Tooling",
"deploymentOrder": 1}]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": "env-1",
+ "provisionedResources": [{"name": "s3BucketPath", "value": ""}],
+ }
+ with pytest.raises(RuntimeError, match="s3BucketPath provisioned
resource is empty"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ def test_get_project_s3_path_raises_on_malformed_uri(self):
+ """Raises RuntimeError when s3BucketPath has an unexpected format."""
+ self._stub_tooling_blueprint_lookup(
+ environments=[{"id": "env-1", "name": "Tooling",
"deploymentOrder": 1}]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": "env-1",
+ "provisionedResources": [{"name": "s3BucketPath", "value":
"not-an-s3-uri"}],
+ }
+ with pytest.raises(RuntimeError, match="unexpected format"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
+
+ def test_get_project_s3_path_wraps_client_error(self):
+ """Wraps boto ClientError from DataZone APIs in RuntimeError."""
+ from botocore.exceptions import ClientError
+
+ error_response = {"Error": {"Code": "AccessDenied", "Message": "no
access"}}
+ self.mock_client.list_environment_blueprints.side_effect = ClientError(
+ error_response, "ListEnvironmentBlueprints"
+ )
+ with pytest.raises(RuntimeError, match="Failed to resolve default
tooling environment"):
+ self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)
# --- get_notebook_outputs ---
+ def _stub_project_bucket(self, bucket: str = "test-bucket") -> None:
+ """Configure the mock client to resolve the project bucket to
``bucket``."""
+ self._stub_tooling_blueprint_lookup(
+ environments=[{"id": "env-1", "name": "Tooling",
"deploymentOrder": 1}]
+ )
+ self.mock_client.get_environment.return_value = {
+ "id": "env-1",
+ "provisionedResources": [
+ {"name": "s3BucketPath", "value":
f"s3://{bucket}/dzd_x/{PROJECT_ID}/dev"},
+ ],
+ }
+
def test_get_notebook_outputs_success(self):
"""Reads and parses JSON outputs from S3."""
outputs = {"name": "Alice", "age": 42}
+ bucket = "test-bucket"
+ self._stub_project_bucket(bucket)
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.return_value =
json.dumps(outputs)
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
assert result == outputs
- expected_bucket =
f"amazon-sagemaker-{ACCOUNT_ID}-{REGION}-{PROJECT_ID}"
expected_key =
f"sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json"
- mock_s3_hook_cls.return_value.read_key.assert_called_once_with(
- key=expected_key, bucket_name=expected_bucket
- )
+
mock_s3_hook_cls.return_value.read_key.assert_called_once_with(key=expected_key,
bucket_name=bucket)
def test_get_notebook_outputs_no_such_key(self):
"""Returns empty dict when the outputs file does not exist in S3."""
from botocore.exceptions import ClientError
error_response = {"Error": {"Code": "NoSuchKey", "Message": "Not
found"}}
+ self._stub_project_bucket()
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.side_effect =
ClientError(error_response, "GetObject")
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
@@ -361,22 +535,14 @@ class TestSageMakerUnifiedStudioNotebookHook:
from botocore.exceptions import ClientError
error_response = {"Error": {"Code": "404", "Message": "Not Found"}}
+ self._stub_project_bucket()
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.side_effect =
ClientError(error_response, "HeadObject")
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
@@ -384,21 +550,14 @@ class TestSageMakerUnifiedStudioNotebookHook:
def test_get_notebook_outputs_invalid_json(self):
"""Returns empty dict when S3 file contains invalid JSON."""
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ self._stub_project_bucket()
+
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.return_value = "not valid
json {{{"
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
@@ -406,21 +565,14 @@ class TestSageMakerUnifiedStudioNotebookHook:
def test_get_notebook_outputs_non_dict_json(self):
"""Returns empty dict when S3 file contains valid JSON but not a
dict."""
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ self._stub_project_bucket()
+
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.return_value =
json.dumps(["a", "b"])
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
@@ -428,21 +580,14 @@ class TestSageMakerUnifiedStudioNotebookHook:
def test_get_notebook_outputs_unexpected_exception(self):
"""Returns empty dict on unexpected S3 errors."""
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ self._stub_project_bucket()
+
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.side_effect =
ConnectionError("Network issue")
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
@@ -450,22 +595,34 @@ class TestSageMakerUnifiedStudioNotebookHook:
def test_get_notebook_outputs_empty_dict(self):
"""Returns empty dict when S3 file contains an empty JSON object."""
- with (
- patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "account_id",
new_callable=PropertyMock
- ) as mock_account,
- patch.object(
- SageMakerUnifiedStudioNotebookHook, "conn_region_name",
new_callable=PropertyMock
- ) as mock_region,
- ):
- mock_account.return_value = ACCOUNT_ID
- mock_region.return_value = REGION
+ self._stub_project_bucket()
+
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
mock_s3_hook_cls.return_value.read_key.return_value =
json.dumps({})
result = self.hook.get_notebook_outputs(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
+ owning_project_identifier=PROJECT_ID,
+ )
+
+ assert result == {}
+
+ def
test_get_notebook_outputs_returns_empty_when_bucket_resolution_fails(self):
+ """Returns empty dict when DataZone APIs fail to resolve the project
bucket."""
+ from botocore.exceptions import ClientError
+
+ error_response = {"Error": {"Code": "AccessDenied", "Message": "no
access"}}
+ self.mock_client.list_environments.side_effect =
ClientError(error_response, "ListEnvironments")
+
+ with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls:
+ result = self.hook.get_notebook_outputs(
+ notebook_identifier=NOTEBOOK_ID,
+ notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
assert result == {}
+ # S3Hook is not even instantiated when bucket cannot be resolved.
+ mock_s3_hook_cls.assert_not_called()
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio_notebook.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio_notebook.py
index c02640e8983..1233afb490c 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio_notebook.py
@@ -378,6 +378,7 @@ class TestSageMakerUnifiedStudioNotebookOperator:
mock_hook.get_notebook_outputs.assert_called_once_with(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
context["ti"].xcom_push.assert_any_call(key="notebook_run_id",
value=NOTEBOOK_RUN_ID)
diff --git
a/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio_notebook.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio_notebook.py
index 2f1be0e39d9..7b309777460 100644
---
a/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio_notebook.py
+++
b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio_notebook.py
@@ -219,6 +219,7 @@ class TestSageMakerUnifiedStudioNotebookSensor:
mock_hook.get_notebook_outputs.assert_called_once_with(
notebook_identifier=NOTEBOOK_ID,
notebook_run_id=NOTEBOOK_RUN_ID,
+ domain_identifier=DOMAIN_ID,
owning_project_identifier=PROJECT_ID,
)
context["ti"].xcom_push.assert_called_once_with(key=f"{NOTEBOOK_OUTPUT_PREFIX}.name",
value="Alice")