This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 381160c0f6 Fix template rendered bucket_key in S3KeySensor (#28340) 381160c0f6 is described below commit 381160c0f63a15957a631da9db875f98bb8e9d64 Author: Sung Yun <107272191+syu...@users.noreply.github.com> AuthorDate: Wed Dec 14 02:47:46 2022 -0500 Fix template rendered bucket_key in S3KeySensor (#28340) --- airflow/providers/amazon/aws/sensors/s3.py | 7 ++++-- tests/providers/amazon/aws/sensors/test_s3_key.py | 27 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py index 40e31596d7..57c9393c0c 100644 --- a/airflow/providers/amazon/aws/sensors/s3.py +++ b/airflow/providers/amazon/aws/sensors/s3.py @@ -86,7 +86,7 @@ class S3KeySensor(BaseSensorOperator): ): super().__init__(**kwargs) self.bucket_name = bucket_name - self.bucket_key = [bucket_key] if isinstance(bucket_key, str) else bucket_key + self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.check_fn = check_fn self.aws_conn_id = aws_conn_id @@ -125,7 +125,10 @@ class S3KeySensor(BaseSensorOperator): return True def poke(self, context: Context): - return all(self._check_key(key) for key in self.bucket_key) + if isinstance(self.bucket_key, str): + return self._check_key(self.bucket_key) + else: + return all(self._check_key(key) for key in self.bucket_key) def get_hook(self) -> S3Hook: """Create and return an S3Hook""" diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py index 8d560e2c82..f0832d3df9 100644 --- a/tests/providers/amazon/aws/sensors/test_s3_key.py +++ b/tests/providers/amazon/aws/sensors/test_s3_key.py @@ -126,6 +126,33 @@ class TestS3KeySensor: mock_head_object.assert_called_once_with("key", "bucket") + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object") + def test_parse_list_of_bucket_keys_from_jinja(self, mock_head_object): + mock_head_object.return_value = None + mock_head_object.side_effect = [{"ContentLength": 0}, {"ContentLength": 0}] + + Variable.set("test_bucket_key", ["s3://bucket/file1", "s3://bucket/file2"]) + + execution_date = timezone.datetime(2020, 1, 1) + + dag = DAG("test_s3_key", start_date=execution_date, render_template_as_native_obj=True) + op = S3KeySensor( + task_id="s3_key_sensor", + bucket_key="{{ var.value.test_bucket_key }}", + bucket_name=None, + dag=dag, + ) + + dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id="test") + ti = TaskInstance(task=op) + ti.dag_run = dag_run + context = ti.get_template_context() + ti.render_templates(context) + op.poke(None) + + mock_head_object.assert_any_call("file1", "bucket") + mock_head_object.assert_any_call("file2", "bucket") + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.head_object") def test_poke(self, mock_head_object): op = S3KeySensor(task_id="s3_key_sensor", bucket_key="s3://test_bucket/file")