o-nikolas commented on code in PR #28282:
URL: https://github.com/apache/airflow/pull/28282#discussion_r1046337674


##########
airflow/providers/amazon/aws/hooks/emr.py:
##########
@@ -202,6 +202,74 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             },
         }
 
+    def is_cluster_available(self, emr_cluster_id, cluster_states):

Review Comment:
   All of these methods are missing docstrings, and this one is missing type hints.
   
   I also don't see this method used anywhere?
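   
   For example, a possible signature with type hints and a docstring, just a sketch based on what the method already does in this diff:
   
   ```python
   def is_cluster_available(self, emr_cluster_id: str, cluster_states: list[str]) -> bool:
       """
       Check whether exactly one cluster with the given id is in one of the provided states.

       :param emr_cluster_id: id of the EMR cluster to look up
       :param cluster_states: cluster states passed to ``list_clusters`` as a filter
       :return: True if exactly one matching cluster is found, False if none
       """
   ```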



##########
airflow/providers/amazon/aws/hooks/emr.py:
##########
@@ -202,6 +202,74 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             },
         }
 
+    def is_cluster_available(self, emr_cluster_id, cluster_states):
+        response = self.get_conn().list_clusters(ClusterStates=cluster_states)
+        matching_clusters = list(
+            filter(lambda cluster: cluster["Id"] == emr_cluster_id, response["Clusters"])
+        )
+
+        if len(matching_clusters) == 1:
+            emr_cluster_name = matching_clusters[0]["Name"]
+            self.log.info("Found cluster name = %s id = %s", emr_cluster_name, emr_cluster_id)
+            return True
+        elif len(matching_clusters) > 1:
+            raise AirflowException(f"More than one cluster found for Id {emr_cluster_id}")
+        else:
+            self.log.info("No cluster found for Id %s", emr_cluster_id)
+            return False
+
+    def _get_list_of_steps_already_triggered(
+        self, cluster_id: str, step_states: list[str]
+    ) -> list[tuple[str, str]]:
+
+        response = self.get_conn().list_steps(
+            ClusterId=cluster_id,
+            StepStates=step_states,
+        )
+        steps_name_id = [(step["Name"], step["Id"]) for step in response["Steps"]]
+        print(steps_name_id)
+        return steps_name_id
+
+    def _cancel_list_of_steps_already_triggered(
+        self, steps: list[dict], cluster_id: str, step_states: list[str]
+    ):
+        names_list = self._get_list_of_steps_already_triggered(cluster_id, step_states)
+
+        self.log.info(steps)

Review Comment:
   This list of steps is being logged with no context, same as on lines 245 and 246. They also feel like `debug`-level logs.
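   
   For example, something like this (just a sketch; the message wording is only a guess at the intent):
   
   ```python
   self.log.debug("Steps requested by the operator: %s", steps)
   ```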



##########
airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -71,6 +71,9 @@ def __init__(
         aws_conn_id: str = "aws_default",
         steps: list[dict] | str | None = None,
         wait_for_completion: bool = False,
+        cancel_existing_steps: bool = True,
+        steps_states: list[str],
+        cancellation_option: str = "SEND_INTERRUPT",

Review Comment:
   Can you update the docstring above with these new params?
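   
   For example, something along these lines (the descriptions are only my guess at the intent; ``TERMINATE_PROCESS`` is the other value the EMR ``CancelSteps`` API accepts):
   
   ```
   :param cancel_existing_steps: whether to cancel steps already on the cluster before submitting the new ones
   :param steps_states: step states to consider when looking for existing steps to cancel, e.g. ``["PENDING", "RUNNING"]``
   :param cancellation_option: how EMR should cancel the steps, ``"SEND_INTERRUPT"`` or ``"TERMINATE_PROCESS"``
   ```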



##########
airflow/providers/amazon/aws/hooks/emr.py:
##########
@@ -202,6 +202,74 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             },
         }
 
+    def is_cluster_available(self, emr_cluster_id, cluster_states):
+        response = self.get_conn().list_clusters(ClusterStates=cluster_states)
+        matching_clusters = list(
+            filter(lambda cluster: cluster["Id"] == emr_cluster_id, response["Clusters"])
+        )
+
+        if len(matching_clusters) == 1:
+            emr_cluster_name = matching_clusters[0]["Name"]
+            self.log.info("Found cluster name = %s id = %s", emr_cluster_name, emr_cluster_id)
+            return True
+        elif len(matching_clusters) > 1:
+            raise AirflowException(f"More than one cluster found for Id {emr_cluster_id}")
+        else:
+            self.log.info("No cluster found for Id %s", emr_cluster_id)
+            return False
+
+    def _get_list_of_steps_already_triggered(
+        self, cluster_id: str, step_states: list[str]
+    ) -> list[tuple[str, str]]:
+
+        response = self.get_conn().list_steps(
+            ClusterId=cluster_id,
+            StepStates=step_states,
+        )
+        steps_name_id = [(step["Name"], step["Id"]) for step in response["Steps"]]
+        print(steps_name_id)
+        return steps_name_id
+
+    def _cancel_list_of_steps_already_triggered(
+        self, steps: list[dict], cluster_id: str, step_states: list[str]
+    ):
+        names_list = self._get_list_of_steps_already_triggered(cluster_id, step_states)
+
+        self.log.info(steps)
+
+        steps_name_list = [step["Name"] for step in steps if "Name" in step]

Review Comment:
   Mostly just curious, can steps ever not have a name?



##########
tests/providers/amazon/aws/hooks/test_emr.py:
##########
@@ -190,3 +190,100 @@ def test_get_cluster_id_by_name(self):
         no_match = hook.get_cluster_id_by_name("foo", ["RUNNING", "WAITING", "BOOTSTRAPPING"])
 
         assert no_match is None
+
+    @mock_emr
+    def test_send_cancel_steps_first_invocation(self):
+        """
+        Test that we can resolve cluster id by cluster name.

Review Comment:
   This looks like copy/paste (same for the method below).
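   
   Maybe something like this instead (wording is only a suggestion based on the test names):
   
   ```python
   """Test that send_cancel_steps is a no-op when no steps have been triggered yet."""
   ```
   
   and for the second one:
   
   ```python
   """Test that already-triggered steps with a matching name are cancelled before re-submission."""
   ```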



##########
airflow/providers/amazon/aws/hooks/emr.py:
##########
@@ -202,6 +202,74 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             },
         }
 
+    def is_cluster_available(self, emr_cluster_id, cluster_states):
+        response = self.get_conn().list_clusters(ClusterStates=cluster_states)
+        matching_clusters = list(
+            filter(lambda cluster: cluster["Id"] == emr_cluster_id, response["Clusters"])
+        )
+
+        if len(matching_clusters) == 1:
+            emr_cluster_name = matching_clusters[0]["Name"]
+            self.log.info("Found cluster name = %s id = %s", emr_cluster_name, emr_cluster_id)
+            return True
+        elif len(matching_clusters) > 1:
+            raise AirflowException(f"More than one cluster found for Id {emr_cluster_id}")
+        else:
+            self.log.info("No cluster found for Id %s", emr_cluster_id)
+            return False
+
+    def _get_list_of_steps_already_triggered(
+        self, cluster_id: str, step_states: list[str]
+    ) -> list[tuple[str, str]]:
+
+        response = self.get_conn().list_steps(
+            ClusterId=cluster_id,
+            StepStates=step_states,
+        )
+        steps_name_id = [(step["Name"], step["Id"]) for step in response["Steps"]]
+        print(steps_name_id)

Review Comment:
   Debug print left in the code.
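   
   If the information is worth keeping, maybe a debug log instead (sketch):
   
   ```python
   self.log.debug("Already triggered steps (name, id): %s", steps_name_id)
   ```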



##########
tests/providers/amazon/aws/hooks/test_emr.py:
##########
@@ -190,3 +190,100 @@ def test_get_cluster_id_by_name(self):
         no_match = hook.get_cluster_id_by_name("foo", ["RUNNING", "WAITING", "BOOTSTRAPPING"])
 
         assert no_match is None
+
+    @mock_emr
+    def test_send_cancel_steps_first_invocation(self):
+        """
+        Test that we can resolve cluster id by cluster name.
+        """
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
+
+        job_flow = hook.create_job_flow(
+            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": True}}
+        )
+
+        job_flow_id = job_flow["JobFlowId"]
+
+        step = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": "step_1",
+            }
+        ]
+
+        did_not_execute_response = hook.send_cancel_steps(
+            steps_states=["PENDING", "RUNNING"],
+            emr_cluster_id=job_flow_id,
+            cancellation_option="SEND_INTERRUPT",
+            steps=step,
+        )
+
+        assert did_not_execute_response is None
+
+    @mock_emr
+    @pytest.mark.parametrize("num_steps", [1, 2, 3, 4])
+    def test_send_cancel_steps_on_pre_existing_step_name(self, num_steps):
+        """
+        Test that we can resolve cluster id by cluster name.
+        """
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
+
+        job_flow = hook.create_job_flow(
+            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": True}}
+        )
+
+        job_flow_id = job_flow["JobFlowId"]
+
+        steps = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": f"step_{i}",
+            }
+            for i in range(num_steps)
+        ]
+
+        retry_step = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": "retry_step_1",
+            }
+        ]
+
+        triggered = hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps)
+
+        triggered_steps = hook._get_list_of_steps_already_triggered(job_flow_id, ["PENDING", "RUNNING"])
+        assert len(triggered_steps) == num_steps == len(triggered)
+
+        cancel_steps = hook._cancel_list_of_steps_already_triggered(
+            steps + retry_step, job_flow_id, ["PENDING", "RUNNING"]
+        )
+
+        assert len(cancel_steps) == num_steps
+
+        with pytest.raises(NotImplementedError):
+            response = hook.send_cancel_steps(
+                steps_states=["PENDING", "RUNNING"],
+                emr_cluster_id=job_flow_id,
+                cancellation_option="SEND_INTERRUPT",
+                steps=steps + retry_step,
+            )
+
+            assert response
+
+        # assert set([status['Status'] for status in response['CancelStepsInfoList'][0]]) \
+        #        == {'SUBMITTED'} or None
+        #
+        # assert [step['StepId'] for step in response['CancelStepsInfoList'][0] if
+        #         step['Status'] in ['SUBMITTED']] == [step_id for step_name, step_id in cancel_steps]

Review Comment:
   Should be removed.



##########
tests/providers/amazon/aws/hooks/test_emr.py:
##########
@@ -190,3 +190,100 @@ def test_get_cluster_id_by_name(self):
         no_match = hook.get_cluster_id_by_name("foo", ["RUNNING", "WAITING", "BOOTSTRAPPING"])
 
         assert no_match is None
+
+    @mock_emr
+    def test_send_cancel_steps_first_invocation(self):
+        """
+        Test that we can resolve cluster id by cluster name.
+        """
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
+
+        job_flow = hook.create_job_flow(
+            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": True}}
+        )
+
+        job_flow_id = job_flow["JobFlowId"]
+
+        step = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": "step_1",
+            }
+        ]
+
+        did_not_execute_response = hook.send_cancel_steps(
+            steps_states=["PENDING", "RUNNING"],
+            emr_cluster_id=job_flow_id,
+            cancellation_option="SEND_INTERRUPT",
+            steps=step,
+        )
+
+        assert did_not_execute_response is None
+
+    @mock_emr
+    @pytest.mark.parametrize("num_steps", [1, 2, 3, 4])
+    def test_send_cancel_steps_on_pre_existing_step_name(self, num_steps):
+        """
+        Test that we can resolve cluster id by cluster name.
+        """
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="emr_default")
+
+        job_flow = hook.create_job_flow(
+            {"Name": "test_cluster", "Instances": {"KeepJobFlowAliveWhenNoSteps": True}}
+        )
+
+        job_flow_id = job_flow["JobFlowId"]
+
+        steps = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": f"step_{i}",
+            }
+            for i in range(num_steps)
+        ]
+
+        retry_step = [
+            {
+                "ActionOnFailure": "test_step",
+                "HadoopJarStep": {
+                    "Args": ["test args"],
+                    "Jar": "test.jar",
+                },
+                "Name": "retry_step_1",
+            }
+        ]
+
+        triggered = hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps)
+
+        triggered_steps = hook._get_list_of_steps_already_triggered(job_flow_id, ["PENDING", "RUNNING"])
+        assert len(triggered_steps) == num_steps == len(triggered)
+
+        cancel_steps = hook._cancel_list_of_steps_already_triggered(
+            steps + retry_step, job_flow_id, ["PENDING", "RUNNING"]
+        )
+
+        assert len(cancel_steps) == num_steps
+
+        with pytest.raises(NotImplementedError):

Review Comment:
   This looks odd to me; why `NotImplementedError`?
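   
   If it is only because moto's `mock_emr` does not implement `cancel_steps`, one alternative (just a sketch, assuming `send_cancel_steps` calls `self.get_conn().cancel_steps(...)` and that `get_conn()` returns the same cached client) would be to stub that single call and assert on it:
   
   ```python
   from unittest import mock

   client = hook.get_conn()
   # Stub only cancel_steps; list_clusters/list_steps still go to the moto backend.
   with mock.patch.object(client, "cancel_steps", return_value={"CancelStepsInfoList": []}) as cancel_mock:
       hook.send_cancel_steps(
           steps_states=["PENDING", "RUNNING"],
           emr_cluster_id=job_flow_id,
           cancellation_option="SEND_INTERRUPT",
           steps=steps + retry_step,
       )
   cancel_mock.assert_called_once()
   ```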


