This is an automated email from the ASF dual-hosted git repository. yasith pushed a commit to branch sdk-batch-jobs in repository https://gitbox.apache.org/repos/asf/airavata.git
commit e88f0ee3db597eedc43e1f5be9e4d8b8923dc241 Author: yasithdev <[email protected]> AuthorDate: Thu Jul 31 05:28:08 2025 -0500 implement missing functions for batch sdk --- .../airavata_experiments/airavata.py | 67 +++++++++++++++------- .../airavata_experiments/plan.py | 16 +++--- .../airavata_experiments/runtime.py | 59 ++++++++++++++++--- .../airavata_experiments/task.py | 10 +++- 4 files changed, 116 insertions(+), 36 deletions(-) diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py index 05e91ab494..362dbdcf55 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py @@ -36,6 +36,7 @@ from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, Us from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingModel from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory from airavata.model.appcatalog.groupresourceprofile.ttypes import GroupResourceProfile +from airavata.model.status.ttypes import JobStatus, JobState, ExperimentStatus, ExperimentState warnings.filterwarnings("ignore", category=DeprecationWarning) logger = logging.getLogger("airavata_sdk.clients") @@ -446,7 +447,36 @@ class AiravataOperator: assert process_id is not None, f"Expected process_id, got {process_id}" url_path = os.path.join(process_id, remote_file) filemgr_svc_download_url = f"{self.filemgr_svc_url()}/download/live/{url_path}" - + + def execute_cmd(self, agent_ref: str, cmd: str) -> bytes: + """ + Execute a command on a remote directory of a storage resource + TODO add data_svc fallback + + Return Path: /{project_name}/{experiment_name} + + """ + res = requests.post(f"{self.connection_svc_url()}/agent/execute/shell", json={ + "agentId": agent_ref, + "envName": agent_ref, + "workingDir": ".", + "arguments": ["sh", "-c", f"{cmd} | base64 -w0"] + }) + data = res.json() + if data["error"] is not None: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"{self.connection_svc_url()}/agent/execute/shell/{exc_id}") + data = res.json() + if data["executed"]: + content = data["responseString"] + import base64 + content = base64.b64decode(content) + return content + time.sleep(1) + def cat_file(self, process_id: str, agent_ref: str, sr_host: str, remote_file: str, remote_dir: str) -> bytes: """ Download files from a remote directory of a storage resource to a local directory @@ -674,14 +704,15 @@ class AiravataOperator: # wait until task begins, then get job id print(f"[AV] Experiment {experiment_name} WAITING until task begins...") job_id = job_state = None - while job_state is None: + while job_id in [None, "N/A"]: try: job_id, job_state = self.get_task_status(ex_id) except: time.sleep(2) else: time.sleep(2) - print(f"[AV] Experiment {experiment_name} - Task {job_state} with id: {job_id}") + assert job_state is not None, f"Job state is None for job id: {job_id}" + print(f"[AV] Experiment {experiment_name} - Task {job_state.name} with id: {job_id}") return LaunchState( experiment_id=ex_id, @@ -692,14 +723,12 @@ class AiravataOperator: sr_host=storage.hostName, ) - def get_experiment_status(self, experiment_id: str) -> Literal['CREATED', 'VALIDATED', 'SCHEDULED', 'LAUNCHED', 'EXECUTING', 'CANCELING', 'CANCELED', 'COMPLETED', 'FAILED']: - states = ["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"] + def get_experiment_status(self, experiment_id: str) -> ExperimentState: status = self.api_server_client.get_experiment_status(self.airavata_token, experiment_id) - state = status.state.name - if state in states: - return state - else: - return "FAILED" + if status is None: + return ExperimentState.CREATED + assert isinstance(status, ExperimentStatus) + return status.state def stop_experiment(self, experiment_id: str): status = self.api_server_client.terminate_experiment( @@ -794,16 +823,12 @@ class AiravataOperator: Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="Default"), ] - def get_task_status(self, experiment_id: str) -> tuple[str, Literal["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] | None]: - states = ["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] - job_details: dict = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore - print(f"[av] Job details: {job_details}") + def get_task_status(self, experiment_id: str) -> tuple[str, JobState]: + job_details: dict[str, JobStatus] = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore job_id = job_state = None - # get the most recent job id and state - for job_id, v in job_details.items(): # type: ignore - if v.reason in states: - job_state = v.reason - else: - job_state = states[int(v.jobState)] - return job_id or "N/A", job_state # type: ignore + for job_id, v in job_details.items(): + job_state = v.jobState + return job_id or "N/A", job_state or JobState.UNKNOWN + + JobState = JobState diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py index 8560620616..7f35e3c6d6 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/plan.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/plan.py @@ -66,11 +66,13 @@ class Plan(pydantic.BaseModel): statuses.append(task.status()) return statuses - def __stage_stop__(self) -> None: - print("Stopping task(s)...") - for task in self.tasks: - task.stop() - print("Task(s) stopped.") + def __stage_stop__(self, runs: list[int] = []) -> None: + runs = runs if len(runs) > 0 else list(range(len(self.tasks))) + print(f"Stopping task(s): {runs}") + for i, task in enumerate(self.tasks): + if i in runs: + task.stop() + print(f"Task(s) stopped: {runs}") def __stage_fetch__(self, local_dir: str) -> list[list[str]]: print("Fetching results...") @@ -119,8 +121,8 @@ class Plan(pydantic.BaseModel): assert os.path.isdir(local_dir) self.__stage_fetch__(local_dir) - def stop(self) -> None: - self.__stage_stop__() + def stop(self, runs: list[int] = []) -> None: + self.__stage_stop__(runs) self.save() def export(self, filename: str) -> None: diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py index 00f039f459..054f94ec82 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/runtime.py @@ -16,7 +16,7 @@ from __future__ import annotations import abc -from typing import Any +from typing import Any, Literal from pathlib import Path import os @@ -24,6 +24,28 @@ import pydantic # from .task import Task Task = Any +States = Literal[ + # Experiment States + 'CREATED', + 'VALIDATED', + 'SCHEDULED', + 'LAUNCHED', + 'EXECUTING', + 'CANCELING', + 'CANCELED', + 'COMPLETED', + 'FAILED', + # Job States + 'SUBMITTED', + 'QUEUED', + 'ACTIVE', + 'COMPLETE', + 'CANCELED', + 'FAILED', + 'SUSPENDED', + 'UNKNOWN', + 'NON_CRITICAL_FAIL', +] class Runtime(abc.ABC, pydantic.BaseModel): @@ -36,6 +58,9 @@ class Runtime(abc.ABC, pydantic.BaseModel): @abc.abstractmethod def execute_py(self, libraries: list[str], code: str, task: Task) -> None: ... + @abc.abstractmethod + def execute_cmd(self, cmd: str, task: Task) -> bytes: ... + @abc.abstractmethod def status(self, task: Task) -> tuple[str, str]: ... @@ -87,6 +112,9 @@ class Mock(Runtime): task.agent_ref = str(uuid.uuid4()) task.ref = str(uuid.uuid4()) + def execute_cmd(self, cmd: str, task: Task) -> bytes: + return b"" + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: pass @@ -158,6 +186,23 @@ class Remote(Runtime): except Exception as e: print(f"[Remote] Failed to launch experiment: {repr(e)}") raise e + + def execute_cmd(self, cmd: str, task: Task) -> bytes: + assert task.ref is not None + assert task.agent_ref is not None + assert task.pid is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) + try: + result = av.execute_cmd(task.agent_ref, cmd) + return result + except Exception as e: + print(f"[Remote] Failed to execute command: {repr(e)}") + return b"" + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: assert task.ref is not None @@ -169,7 +214,7 @@ class Remote(Runtime): result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args) print(result) - def status(self, task: Task) -> tuple[str, str]: + def status(self, task: Task) -> tuple[str, States]: assert task.ref is not None assert task.agent_ref is not None @@ -177,10 +222,10 @@ class Remote(Runtime): av = AiravataOperator(os.environ['CS_ACCESS_TOKEN']) # prioritize job state, fallback to experiment state job_id, job_state = av.get_task_status(task.ref) - if not job_state or job_state == "UN_SUBMITTED": - return job_id, av.get_experiment_status(task.ref) + if job_state in [AiravataOperator.JobState.UNKNOWN, AiravataOperator.JobState.NON_CRITICAL_FAIL]: + return job_id, av.get_experiment_status(task.ref).name else: - return job_id, job_state + return job_id, job_state.name def signal(self, signal: str, task: Task) -> None: assert task.ref is not None @@ -259,5 +304,5 @@ def find_runtimes( out_runtimes.append(r) return out_runtimes -def is_terminal_state(x): - return x in ["CANCELED", "COMPLETED", "FAILED"] \ No newline at end of file +def is_terminal_state(x: States) -> bool: + return x in ["CANCELED", "COMPLETE", "COMPLETED", "FAILED"] \ No newline at end of file diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/task.py b/dev-tools/airavata-python-sdk/airavata_experiments/task.py index bcda796518..3700bf2f66 100644 --- a/dev-tools/airavata-python-sdk/airavata_experiments/task.py +++ b/dev-tools/airavata-python-sdk/airavata_experiments/task.py @@ -72,7 +72,11 @@ class Task(pydantic.BaseModel): assert self.ref is not None from pathlib import Path Path(local_dir).mkdir(parents=True, exist_ok=True) - return self.runtime.download(file, local_dir, self) + try: + return self.runtime.download(file, local_dir, self) + except Exception as e: + print(f"[Remote] Failed to download file: {repr(e)}") + return "" def download_all(self, local_dir: str) -> list[str]: assert self.ref is not None @@ -92,6 +96,10 @@ class Task(pydantic.BaseModel): def cat(self, file: str) -> bytes: assert self.ref is not None return self.runtime.cat(file, self) + + def exec(self, cmd: str) -> bytes: + assert self.ref is not None + return self.runtime.execute_cmd(cmd, self) def stop(self) -> None: assert self.ref is not None
