Lee-W commented on code in PR #35226:
URL: https://github.com/apache/airflow/pull/35226#discussion_r1376049250


##########
tests/providers/amazon/aws/sensors/test_batch.py:
##########
@@ -106,215 +134,167 @@ def 
test_execute_failure_in_deferrable_mode_with_soft_fail(self, deferrable_batc
         with pytest.raises(AirflowSkipException):
             deferrable_batch_sensor.execute_complete(context={}, 
event={"status": "failure"})
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
-    )
+    @pytest.mark.parametrize("soft_fail, expected_exception", SOFT_FAIL_CASES)
     @pytest.mark.parametrize(
         "state, error_message",
         (
-            (
+            pytest.param(
                 BatchClientHook.FAILURE_STATE,
                 f"Batch sensor failed. AWS Batch job status: 
{BatchClientHook.FAILURE_STATE}",
+                id="failure",
+            ),
+            pytest.param(
+                "INVALID", "Batch sensor failed. Unknown AWS Batch job status: 
INVALID", id="unknown"
             ),
-            ("unknown_state", "Batch sensor failed. Unknown AWS Batch job 
status: unknown_state"),
         ),
     )
-    @mock.patch.object(BatchClientHook, "get_job_description")
     def test_fail_poke(
         self,
         mock_get_job_description,
-        batch_sensor: BatchSensor,
         state,
-        error_message,
+        error_message: str,
         soft_fail,
         expected_exception,
     ):
         mock_get_job_description.return_value = {"status": state}
-        batch_sensor.soft_fail = soft_fail
+        batch_sensor = BatchSensor(**self.default_op_kwargs, 
soft_fail=soft_fail)
         with pytest.raises(expected_exception, match=error_message):
             batch_sensor.poke({})
+        mock_get_job_description.assert_called_once_with(JOB_ID)
 
 
-@pytest.fixture(scope="module")
-def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
-    return BatchComputeEnvironmentSensor(
-        task_id="test_batch_compute_environment_sensor",
-        compute_environment=ENVIRONMENT_NAME,
-    )
+class TestBatchComputeEnvironmentSensor(BaseBatchSensorsTests):
+    op_class = BatchComputeEnvironmentSensor
 
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
+        self.default_op_kwargs = dict(

Review Comment:
   same as the suggestion above



##########
tests/providers/amazon/aws/operators/test_batch.py:
##########
@@ -305,40 +337,55 @@ def test_monitor_job_with_logs(
 
 
 class TestBatchCreateComputeEnvironmentOperator:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):

Review Comment:
   https://github.com/apache/airflow/blob/main/TESTING.rst#id5
   
   > We are in the process of converting all unit tests to standard "asserts" 
and pytest fixtures so if you find some tests that are still using classic 
setUp/tearDown approach or unittest asserts, feel free to convert them to 
pytest.
   
   Not sure whether we should use a fixture in this case



##########
tests/providers/amazon/aws/sensors/test_batch.py:
##########
@@ -29,59 +30,86 @@
 )
 from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
 
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+
+
 TASK_ID = "batch_job_sensor"
 JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
 AWS_REGION = "eu-west-1"
 ENVIRONMENT_NAME = "environment_name"
 JOB_QUEUE = "job_queue"
 
-
-@pytest.fixture(scope="module")
-def batch_sensor() -> BatchSensor:
-    return BatchSensor(
-        task_id="batch_job_sensor",
-        job_id=JOB_ID,
-    )
+SOFT_FAIL_CASES = [
+    pytest.param(False, AirflowException, id="not-soft-fail"),
+    pytest.param(True, AirflowSkipException, id="soft-fail"),
+]
 
 
 @pytest.fixture(scope="module")
 def deferrable_batch_sensor() -> BatchSensor:
     return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, 
deferrable=True)
 
 
-class TestBatchSensor:
-    @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_success_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
-        mock_get_job_description.return_value = {"status": "SUCCEEDED"}
-        assert batch_sensor.poke({}) is True
-        mock_get_job_description.assert_called_once_with(JOB_ID)
+@pytest.fixture
+def mock_get_job_description():
+    with mock.patch.object(BatchClientHook, "get_job_description") as m:
+        yield m
 
-    @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_failure_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
-        mock_get_job_description.return_value = {"status": "FAILED"}
-        with pytest.raises(AirflowException, match="Batch sensor failed. AWS 
Batch job status: FAILED"):
-            batch_sensor.poke({})
 
-        mock_get_job_description.assert_called_once_with(JOB_ID)
+@pytest.fixture
+def mock_batch_client():
+    with mock.patch.object(BatchClientHook, "client") as m:
+        yield m
+
+
+class BaseBatchSensorsTests:
+    """Base test class for Batch Sensors."""
+
+    op_class: type[AwsBaseSensor]
+    default_op_kwargs: dict[str, Any]
+
+    def test_base_aws_op_attributes(self):
+        op = self.op_class(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+        op = self.op_class(
+            **self.default_op_kwargs,
+            aws_conn_id="aws-test-custom-conn",
+            region_name="eu-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.hook.aws_conn_id == "aws-test-custom-conn"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
 
-    @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_invalid_state(self, mock_get_job_description, 
batch_sensor: BatchSensor):
-        mock_get_job_description.return_value = {"status": "INVALID"}
-        with pytest.raises(
-            AirflowException, match="Batch sensor failed. Unknown AWS Batch 
job status: INVALID"
-        ):
-            batch_sensor.poke({})
 
+class TestBatchSensor(BaseBatchSensorsTests):
+    op_class = BatchSensor
+
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
+        self.default_op_kwargs = dict(
+            task_id="test_batch_sensor",
+            job_id=JOB_ID,
+        )

Review Comment:
   According to https://github.com/apache/airflow/pull/33761/files, we probably 
would like to change it to 
   
   ```suggestion
           self.default_op_kwargs = {
               "task_id": "test_batch_sensor",
               "job_id": JOB_ID,
           }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to