hussein-awala commented on code in PR #34071:
URL: https://github.com/apache/airflow/pull/34071#discussion_r1316248117


##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -114,6 +116,53 @@ def from_json(cls, data: str) -> RunState:
         return RunState(**json.loads(data))
 
 
+class ClusterState:

Review Comment:
   I think we can move it to a new module, wdyt?



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -114,6 +116,53 @@ def from_json(cls, data: str) -> RunState:
         return RunState(**json.loads(data))
 
 
+class ClusterState:
+    """Utility class for the cluster state concept of Databricks cluster."""
+
+    CLUSTER_LIFE_CYCLE_STATES = [
+        "PENDING",
+        "RUNNING",
+        "RESTARTING",
+        "RESIZING",
+        "TERMINATING",
+        "TERMINATED",
+        "ERROR",
+        "UNKNOWN",
+    ]
+
+    def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
+        self.state = state
+        self.state_message = state_message
+
+    @property
+    def is_terminal(self) -> bool:
+        """True if the current state is a terminal state."""
+        if self.state not in self.CLUSTER_LIFE_CYCLE_STATES:
+            raise AirflowException(f"Unexpected cluster life cycle state: {self.state}")

Review Comment:
   could you move this to `__init__`?
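   A rough sketch of what that could look like (illustrative only, not the PR's final code; the empty-string default would also need rethinking once validation happens at construction time):
   ```python
   from airflow.exceptions import AirflowException


   class ClusterState:
       """Utility class for the cluster state concept of Databricks cluster."""

       CLUSTER_LIFE_CYCLE_STATES = [
           "PENDING", "RUNNING", "RESTARTING", "RESIZING",
           "TERMINATING", "TERMINATED", "ERROR", "UNKNOWN",
       ]

       def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
           # validate once, at construction time, instead of on every property access
           if state not in self.CLUSTER_LIFE_CYCLE_STATES:
               raise AirflowException(f"Unexpected cluster life cycle state: {state}")
           self.state = state
           self.state_message = state_message

       @property
       def is_terminal(self) -> bool:
           """True if the current state is a terminal state."""
           return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")
   ```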



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -388,6 +437,32 @@ def repair_run(self, json: dict) -> None:
         """
         self._do_api_call(REPAIR_RUN_ENDPOINT, json)
 
+    def get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Retrieves run state of the cluster.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = self._do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)
+
+    async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Async version of `get_cluster_state`.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)

Review Comment:
   same here, you can use `from_json`
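   For example, something like this (a sketch only; it assumes the `ClusterState.from_json` fix suggested in the other comment, and renames the local `json` dict so the stdlib `json` module is not shadowed inside the method):
   ```python
   async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
       payload = {"cluster_id": cluster_id}
       response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, payload)
       # extra keys in the API response are absorbed by ClusterState's *args/**kwargs
       return ClusterState.from_json(json.dumps(response))
   ```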



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -404,6 +479,40 @@ def start_cluster(self, json: dict) -> None:
         """
         self._do_api_call(START_CLUSTER_ENDPOINT, json)
 
+    def activate_cluster(self, json: dict, polling: int, timeout: int | None = None) -> None:
+        """
+        Start the cluster, and wait for it to be ready.
+
+        :param json: json dictionary containing cluster specification.
+        :param polling: polling interval in seconds.
+        :param timeout: timeout in seconds. -1 means no timeout.
+        """
+        cluster_id = json['cluster_id']

Review Comment:
   We always use double quotes in the project
   ```suggestion
           cluster_id = json["cluster_id"]
   ```



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -388,6 +437,32 @@ def repair_run(self, json: dict) -> None:
         """
         self._do_api_call(REPAIR_RUN_ENDPOINT, json)
 
+    def get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Retrieves run state of the cluster.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = self._do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)
+
+    async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:

Review Comment:
   wdyt about this name?
   ```suggestion
       async def async_get_cluster_state(self, cluster_id: str) -> ClusterState:
   ```
   
   I know that we have `_a_do_api_call`, but since this will be a public method, we can improve its name and make it clearer.



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -114,6 +116,53 @@ def from_json(cls, data: str) -> RunState:
         return RunState(**json.loads(data))
 
 
+class ClusterState:
+    """Utility class for the cluster state concept of Databricks cluster."""
+
+    CLUSTER_LIFE_CYCLE_STATES = [
+        "PENDING",
+        "RUNNING",
+        "RESTARTING",
+        "RESIZING",
+        "TERMINATING",
+        "TERMINATED",
+        "ERROR",
+        "UNKNOWN",
+    ]
+
+    def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
+        self.state = state
+        self.state_message = state_message
+
+    @property
+    def is_terminal(self) -> bool:
+        """True if the current state is a terminal state."""
+        if self.state not in self.CLUSTER_LIFE_CYCLE_STATES:
+            raise AirflowException(f"Unexpected cluster life cycle state: {self.state}")
+        return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")
+
+    @property
+    def is_running(self) -> bool:
+        """True if the current state is running."""
+        return self.state in ("RUNNING", "RESIZING")
+
+    def __eq__(self, other) -> bool:
+        return (
+            self.state == other.state and \
+            self.state_message == other.state_message
+        )
+
+    def __repr__(self) -> str:
+        return str(self.__dict__)
+
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__)
+
+    @classmethod
+    def from_json(cls, data: str) -> RunState:
+        return RunState(**json.loads(data))

Review Comment:
   ```suggestion
       @classmethod
       def from_json(cls, data: str) -> ClusterState:
           return ClusterState(**json.loads(data))
   ```



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -388,6 +437,32 @@ def repair_run(self, json: dict) -> None:
         """
         self._do_api_call(REPAIR_RUN_ENDPOINT, json)
 
+    def get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Retrieves run state of the cluster.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = self._do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)

Review Comment:
   what about creating it via `from_json`?
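   Same idea as in the async comment above, roughly (again just a sketch, with the local `json` dict renamed so the stdlib module stays usable):
   ```python
   payload = {"cluster_id": cluster_id}
   response = self._do_api_call(GET_CLUSTER_ENDPOINT, payload)
   return ClusterState.from_json(json.dumps(response))
   ```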



##########
airflow/providers/databricks/hooks/databricks.py:
##########
@@ -404,6 +479,40 @@ def start_cluster(self, json: dict) -> None:
         """
         self._do_api_call(START_CLUSTER_ENDPOINT, json)
 
+    def activate_cluster(self, json: dict, polling: int, timeout: int | None = None) -> None:
+        """
+        Start the cluster, and wait for it to be ready.
+
+        :param json: json dictionary containing cluster specification.
+        :param polling: polling interval in seconds.
+        :param timeout: timeout in seconds. -1 means no timeout.
+        """
+        cluster_id = json['cluster_id']
+
+        api_called = False
+        elapsed_time = 0
+
+        while True:
+            run_state = self.get_cluster_state(cluster_id)
+
+            if run_state.is_running:
+                return
+            elif run_state.is_terminal:
+                if not api_called:
+                    self.start_cluster(json)
+                    api_called = True
+                else:
+                    raise AirflowException(
+                        f"Cluster {cluster_id} start failed with '{run_state.state}' state: {run_state.state_message}"
+                    )
+
+            # wait for cluster to start
+            time.sleep(polling)
+
+            elapsed_time += polling
+            if timeout and elapsed_time <= timeout:
+                raise AirflowException(f"Cluster {cluster_id} start timed out after {timeout} seconds")

Review Comment:
   you assume that each iteration takes exactly `polling` seconds, but there is also the time needed to call the API and fetch the state; you could either replace this timeout with a `max_attempts` counter, or implement a real wall-clock timeout based on a start time recorded before entering the loop.
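   For the wall-clock option, a rough sketch (naming and structure are my own assumptions, not the PR's code; it reuses the `time` module the hook already uses for `time.sleep`):
   ```python
   def activate_cluster(self, json: dict, polling: int, timeout: int | None = None) -> None:
       cluster_id = json["cluster_id"]
       # measure real elapsed time, so API latency counts against the timeout too
       deadline = time.monotonic() + timeout if timeout and timeout > 0 else None
       api_called = False

       while True:
           cluster_state = self.get_cluster_state(cluster_id)
           if cluster_state.is_running:
               return
           if cluster_state.is_terminal:
               if api_called:
                   raise AirflowException(
                       f"Cluster {cluster_id} start failed with '{cluster_state.state}' state: "
                       f"{cluster_state.state_message}"
                   )
               self.start_cluster(json)
               api_called = True
           if deadline is not None and time.monotonic() > deadline:
               raise AirflowException(f"Cluster {cluster_id} start timed out after {timeout} seconds")
           time.sleep(polling)
   ```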


